[
  {
    "path": ".claude/skills/checkpoints/SKILL.md",
    "content": "---\nname: checkpoints\ndescription: Guide for checkpointing — saving, loading, and resuming training with CheckpointRecord. Use when the user asks about saving weights, resuming training, checkpoint management, or the checkpoint lifecycle.\n---\n\n# Checkpointing\n\nTinker supports two types of checkpoints and provides utilities for managing them during training.\n\n## Reference\n\nRead these for details:\n- `tinker_cookbook/checkpoint_utils.py` — CheckpointRecord, save/load helpers\n- `docs/save-load.mdx` — Checkpointing guide (save_weights_for_sampler vs save_state)\n\n## Two checkpoint types\n\n| Type | Method | Purpose | Contains |\n|------|--------|---------|----------|\n| **State** | `save_state()` | Resume training | Weights + optimizer state |\n| **Sampler** | `save_weights_for_sampler()` | Sampling / export | Weights only |\n\n```python\n# Save full state (for resumption)\ntc.save_state(name=\"step_100\", ttl_seconds=None)\n\n# Save sampler weights (for sampling/export)\ntc.save_weights_for_sampler(name=\"step_100_sampler\", ttl_seconds=None)\n\n# Save both + get a SamplingClient\nsc = tc.save_weights_and_get_sampling_client(name=\"step_100\")\n```\n\n`ttl_seconds=None` means indefinite retention. Set a TTL for intermediate checkpoints to avoid storage bloat.\n\n## CheckpointRecord\n\nTyped dataclass for checkpoint bookkeeping:\n\n```python\nfrom tinker_cookbook.checkpoint_utils import CheckpointRecord\n\nrecord = CheckpointRecord(\n    name=\"step_100\",\n    batch=100,\n    epoch=1,\n    final=False,\n    state_path=\"tinker://...\",\n    sampler_path=\"tinker://...\",\n    extra={\"eval_loss\": 0.5},  # User metadata\n)\n\n# Serialize\nd = record.to_dict()\n\n# Deserialize\nrecord = CheckpointRecord.from_dict(d)\n\n# Check if a field is set\nrecord.has(\"state_path\")  # True\n```\n\n## Save/load helpers\n\n```python\nfrom tinker_cookbook import checkpoint_utils\n\n# Save checkpoint (async)\npaths = await checkpoint_utils.save_checkpoint_async(\n    training_client=tc,\n    name=\"step_100\",\n    log_path=\"/tmp/my_run\",\n    loop_state={\"batch\": 100, \"epoch\": 1},\n    kind=\"both\",           # \"state\", \"sampler\", or \"both\"\n    ttl_seconds=None,\n)\n# paths = {\"state_path\": \"tinker://...\", \"sampler_path\": \"tinker://...\"}\n\n# Load checkpoint list\nrecords = checkpoint_utils.load_checkpoints_file(\"/tmp/my_run\")\n\n# Get last checkpoint\nrecord = checkpoint_utils.get_last_checkpoint(\n    \"/tmp/my_run\",\n    required_key=\"state_path\",  # Only return records with this field\n)\n```\n\n## Resuming training\n\nThe standard pattern (used by `supervised/train.py` and `rl/train.py`):\n\n```python\n# In CLIConfig\nbehavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"  # \"ask\", \"delete\", \"resume\"\n\n# In training loop\nif config.load_checkpoint_path:\n    tc.load_state_with_optimizer(config.load_checkpoint_path)\n```\n\nSet `behavior_if_log_dir_exists=resume` to continue from the last checkpoint in an existing log directory.\n\n## Managing checkpoints (REST API / CLI)\n\nBeyond saving and loading during training, you can manage checkpoints via the REST API or CLI. 
See `/tinker-sdk` for RestClient details and `/tinker-cli` for CLI commands.\n\n```python\nfrom tinker import ServiceClient\nrest = ServiceClient().create_rest_client()\n\n# List all your checkpoints\ncheckpoints = rest.list_user_checkpoints(limit=100)\n\n# Publish a checkpoint (make it publicly accessible)\nrest.publish_checkpoint_from_tinker_path(\"tinker://...\")\n\n# Set TTL (auto-delete after N seconds)\nrest.set_checkpoint_ttl_from_tinker_path(\"tinker://...\", ttl_seconds=86400)\n\n# Delete a checkpoint\nrest.delete_checkpoint_from_tinker_path(\"tinker://...\")\n```\n\nOr via CLI:\n```bash\ntinker checkpoint list\ntinker checkpoint publish <TINKER_PATH>\ntinker checkpoint set-ttl <TINKER_PATH> --ttl 86400\ntinker checkpoint delete <TINKER_PATH>\n```\n\n## Common pitfalls\n- Use `save_state` for resumable checkpoints, `save_weights_for_sampler` for sampling/export\n- `get_last_checkpoint()` returns `None` if no matching checkpoint exists — always check\n- Checkpoint paths start with `tinker://` — they reference remote storage, not local files\n- Set `ttl_seconds` on intermediate checkpoints to avoid accumulating old weights\n- For RLHF pipelines, the SFT stage saves `state_path` (for RL init) and the RM stage saves `sampler_path` (for reward scoring)\n- `delete` is permanent — there is no undo\n"
  },
  {
    "path": ".claude/skills/ci/SKILL.md",
    "content": "---\nname: ci\ndescription: Guide for testing conventions and CI pipelines — unit tests, integration smoke tests, pytest markers, and GitHub Actions workflows. Use when the user asks about testing, CI, running tests, or adding tests for a recipe.\n---\n\n# Testing & CI\n\nThe repo has two layers of testing and two CI workflows.\n\n## Reference\n\nRead these for details:\n- `tests/helpers.py` — `run_recipe()` helper for smoke tests\n- `tests/conftest.py` — Pytest configuration and API key handling\n- `tests/recipes/` — Existing recipe smoke tests\n- `.github/workflows/pytest.yaml` — Unit test CI (every PR)\n- `.github/workflows/smoke-test-recipes.yaml` — Smoke test CI (daily)\n- `CONTRIBUTING.md` — Development setup and test commands\n- `pyproject.toml` — Pytest configuration (testpaths, markers, file patterns)\n\n## Test structure\n\n```\ntinker-cookbook/\n├── tinker_cookbook/\n│   ├── renderers/parsing_test.py     # Unit tests: *_test.py next to source\n│   ├── recipes/math_rl/math_env_test.py\n│   └── ...\n└── tests/\n    ├── conftest.py                   # Skips integration tests without API key\n    ├── helpers.py                    # run_recipe() helper\n    └── recipes/\n        ├── test_recipe_chat_sl.py    # Integration tests: test_recipe_*.py\n        ├── test_recipe_dpo.py\n        └── ...\n```\n\n## Unit tests (`*_test.py`)\n\nColocated with source code. Run without API key.\n\n```bash\nuv run pytest tinker_cookbook/\n```\n\n**Conventions:**\n- File naming: `<module>_test.py` next to the code it tests\n- No network calls, no `TINKER_API_KEY` required\n- Fast (< 1s per test)\n- Use standard pytest features (fixtures, parametrize, marks)\n- Test picklability for components used in distributed rollout\n\n**Example:** `tinker_cookbook/renderers/parsing_test.py`\n\n## Integration / smoke tests (`test_recipe_*.py`)\n\nLive in `tests/recipes/`. Require `TINKER_API_KEY`. Verify recipes can run.\n\n```bash\n# Run all integration tests\nuv run pytest tests/ -v -x -s\n\n# Run a specific recipe test\nuv run pytest tests/recipes/test_recipe_chat_sl.py -v -x -s\n```\n\n**Conventions:**\n- File naming: `tests/recipes/test_recipe_<name>.py`\n- Mark with `@pytest.mark.integration`\n- Use `run_recipe()` from `tests/helpers.py`\n- `run_recipe()` passes `max_steps=2` by default — recipe runs 2 training steps and exits\n- Always pass `behavior_if_log_dir_exists=delete` to avoid CI conflicts\n- Override batch sizes to small values for fast execution\n\n**Template:**\n\n```python\nimport pytest\nfrom tests.helpers import run_recipe\n\n@pytest.mark.integration\ndef test_my_recipe():\n    run_recipe(\n        \"tinker_cookbook.recipes.my_recipe.train\",\n        [\n            \"behavior_if_log_dir_exists=delete\",\n            \"groups_per_batch=4\",\n        ],\n    )\n```\n\n### How `run_recipe()` works\n1. Launches `uv run python -m <module> <args> max_steps=2` as a subprocess\n2. Streams stdout in real time for CI debuggability\n3. Waits for clean exit (exit code 0) within timeout (default: 1800s)\n4. 
Fails if process exits non-zero or times out\n\n## Pytest markers\n\nDefined in `pyproject.toml`:\n- `@pytest.mark.integration` — Requires API key, skipped locally without `TINKER_API_KEY`\n- `@pytest.mark.slow` — Long-running tests\n\n`tests/conftest.py` auto-skips integration tests when `TINKER_API_KEY` is not set (fails on CI if missing).\n\n## CI workflows\n\n### `pytest.yaml` — Unit tests (every PR/push to main)\n```\nTrigger: push to main, pull requests\nRuns: uv run pytest tinker_cookbook/\nRequires: HF_TOKEN (for tokenizer access)\n```\n\n### `smoke-test-recipes.yaml` — Integration tests (daily + manual)\n```\nTrigger: daily at 6am UTC, manual dispatch\nRuns: Each test_recipe_*.py in parallel (matrix strategy)\nRequires: TINKER_API_KEY, HF_TOKEN\nTimeout: 20 min per recipe\nConcurrency: 1 (avoid API contention)\n```\n\nAdding `tests/recipes/test_recipe_<name>.py` is all that's needed — CI auto-discovers it.\n\n## Running pre-commit checks\n\n```bash\nuv run ruff check tinker_cookbook/\nuv run ruff format tinker_cookbook/\nuv run pyright tinker_cookbook/\npre-commit run --all-files\n```\n"
  },
  {
    "path": ".claude/skills/completers/SKILL.md",
    "content": "---\nname: completers\ndescription: Guide for using completers — TokenCompleter and MessageCompleter for text generation during RL rollouts and evaluation. Use when the user asks about generating text, completing messages, or using completers in RL environments.\n---\n\n# Completers\n\nCompleters wrap SamplingClient for convenient text generation. Two levels of abstraction:\n- **TokenCompleter** — low-level, returns tokens + logprobs\n- **MessageCompleter** — high-level, returns parsed Message objects\n\n## Reference\n\nRead these for details:\n- `tinker_cookbook/completers.py` — Implementation\n- `docs/completers.mdx` — Usage guide\n\n## TokenCompleter\n\nGenerates tokens from a ModelInput prompt. Used internally by RL rollouts.\n\n```python\nfrom tinker_cookbook.completers import TinkerTokenCompleter, TokensWithLogprobs\n\ncompleter = TinkerTokenCompleter(\n    sampling_client=sc,\n    max_tokens=256,\n    temperature=1.0,\n)\n\nresult: TokensWithLogprobs = await completer(\n    model_input=prompt,\n    stop=stop_sequences,  # list[str] or list[int]\n)\n# result.tokens: list[int]\n# result.maybe_logprobs: list[float] | None\n```\n\n## MessageCompleter\n\nHigher-level: takes a conversation (list of Messages), returns a Message. Handles rendering and parsing internally.\n\n```python\nfrom tinker_cookbook.completers import TinkerMessageCompleter\n\ncompleter = TinkerMessageCompleter(\n    sampling_client=sc,\n    renderer=renderer,\n    max_tokens=256,\n    temperature=1.0,\n    stop_condition=None,  # Override stop sequences\n)\n\nresponse_message: Message = await completer(messages=[\n    {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n])\n# response_message = {\"role\": \"assistant\", \"content\": \"4\"}\n```\n\n## When to use which\n\n- **TokenCompleter**: RL rollouts, custom generation loops where you need logprobs and token-level control\n- **MessageCompleter**: Evaluation, tool-use environments, multi-turn RL where you work with Messages\n\n## Custom completers\n\nBoth are abstract base classes you can subclass for non-Tinker backends:\n\n```python\nfrom tinker_cookbook.completers import TokenCompleter, MessageCompleter\n\nclass MyTokenCompleter(TokenCompleter):\n    async def __call__(self, model_input, stop) -> TokensWithLogprobs:\n        ...\n\nclass MyMessageCompleter(MessageCompleter):\n    async def __call__(self, messages) -> Message:\n        ...\n```\n\n## Common pitfalls\n- Create a new completer (with a new SamplingClient) after saving weights\n- `TokensWithLogprobs.maybe_logprobs` can be `None` if logprobs weren't requested\n- MessageCompleter uses the renderer for both prompt construction and response parsing\n"
  },
  {
    "path": ".claude/skills/contributing/SKILL.md",
    "content": "---\nname: contributing\ndescription: Guide for contributing to the tinker-cookbook repo — development setup, code style, type checking, PR process, and design conventions. Use when the user asks about how to contribute, set up the dev environment, code style, or project conventions.\n---\n\n# Contributing\n\nGuide for developing and contributing to tinker-cookbook.\n\n## Reference\n\nRead `CONTRIBUTING.md` for the full guide.\n\n## Development setup\n\n```bash\ngit clone https://github.com/thinking-machines-lab/tinker-cookbook.git\ncd tinker-cookbook\nuv sync --extra dev\npre-commit install\n```\n\nThis installs dev dependencies and registers pre-commit hooks (`ruff` formatting/linting).\n\n## Code style\n\n- **Formatter/Linter:** [ruff](https://docs.astral.sh/ruff/) (line length: 100)\n- **Type checker:** [pyright](https://github.com/microsoft/pyright)\n- **Pre-commit hooks** run automatically on every commit\n\n```bash\nuv run ruff check tinker_cookbook/\nuv run ruff format tinker_cookbook/\nuv run pyright tinker_cookbook/\n```\n\n### Typing rules\n- Use explicit types everywhere\n- Avoid `Any` and `type: ignore` — prefer casting\n- Prefer single types over union types\n- Don't add convoluted generics just to satisfy the type checker\n\n## Design conventions\n\n### Builder pattern\nConfig objects build runtime objects:\n- `SupervisedDatasetBuilder` → `SupervisedDataset`\n- `RLDatasetBuilder` → `RLDataset`\n- `EnvGroupBuilder` → group of `Env` objects\n\nConfig objects use `@chz.chz` decorator. They have a `__call__` method that builds the runtime object.\n\n### Config/runtime separation\n- **Config:** `@chz.chz` dataclasses, serializable, lightweight\n- **Runtime:** Regular classes or dataclasses, heavyweight (datasets, clients)\n\n### Training script organization\n- **`tinker_cookbook/<module>/train.py`** — Main training loop with detailed `Config` (not CLI-constructable)\n- **`tinker_cookbook/recipes/<name>/train.py`** — Launch script with `CLIConfig` from command line\n\n### Async\n- All methods that take nontrivial time should be async (especially in RL)\n- Some beginner-oriented code (e.g., `sl_loop.py`) uses sync for simplicity\n\n### Env lifecycle\n- `Env` objects are single-use (no reset)\n- Shared resources managed by `EnvGroupBuilder`, not individual `Env`s\n\n### Dimension notation\nSubscript suffixes on variable names:\n- `_P` = problems, `_G` = groups, `_T` = tokens, `_D` = datums\n- Example: `tokens_P_G_T[p][g][t]` = token t of group g of problem p\n- Flattened: `tokens_PG_T` = problems and groups merged into one dimension\n\n## PR process\n\n1. Create a feature branch from `main`\n2. Make changes with tests\n3. Run `pre-commit run --all-files`\n4. Open PR with clear description\n\nCI runs pre-commit, pyright, and pytest on every PR.\n\n## Testing\n\nSee the `/ci` skill for full testing details.\n\n```bash\n# Unit tests (no API key needed)\nuv run pytest tinker_cookbook/\n\n# Integration tests (requires TINKER_API_KEY)\nuv run pytest tests/\n```\n"
  },
  {
    "path": ".claude/skills/datasets/SKILL.md",
    "content": "---\nname: datasets\ndescription: Guide for dataset construction — SupervisedDatasetBuilder, RLDatasetBuilder, ChatDatasetBuilder, and custom dataset creation from JSONL, HuggingFace, or conversation data. Use when the user asks about datasets, data loading, data preparation, or custom data formats.\n---\n\n# Datasets\n\nThe cookbook uses the builder pattern for datasets: a `*DatasetBuilder` (config) builds a `*Dataset` (runtime).\n\n## Reference\n\nRead these for details:\n- `tinker_cookbook/supervised/types.py` — SupervisedDatasetBuilder, ChatDatasetBuilder, ChatDatasetBuilderCommonConfig\n- `tinker_cookbook/supervised/data.py` — Dataset construction helpers, FromConversationFileBuilder\n- `tinker_cookbook/rl/types.py` — RLDatasetBuilder, RLDataset\n- `docs/training-sampling.mdx` — Data preparation basics\n\n## Supervised datasets\n\n### ChatDatasetBuilderCommonConfig\n\nShared config for all chat-based dataset builders:\n\n```python\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilderCommonConfig\nfrom tinker_cookbook.renderers import TrainOnWhat\n\ncommon_config = ChatDatasetBuilderCommonConfig(\n    model_name_for_tokenizer=\"meta-llama/Llama-3.1-8B\",\n    renderer_name=\"llama3\",\n    max_length=32768,        # Max sequence length\n    batch_size=128,          # Tokens per batch\n    train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES,\n)\n```\n\n### Built-in datasets\n\n```python\nfrom tinker_cookbook.recipes.chat_sl.chat_datasets import NoRobotsBuilder, Tulu3Builder\n\ndataset = NoRobotsBuilder(common_config=common_config)\ndataset = Tulu3Builder(common_config=common_config)\n```\n\n### Custom JSONL file\n\n```python\nfrom tinker_cookbook.supervised.data import FromConversationFileBuilder\n\ndataset = FromConversationFileBuilder(\n    common_config=common_config,\n    file_path=\"/path/to/data.jsonl\",\n    test_size=100,       # Hold out 100 examples for eval\n    shuffle_seed=42,\n)\n```\n\nJSONL format — each line is a conversation:\n```json\n{\"messages\": [{\"role\": \"user\", \"content\": \"...\"}, {\"role\": \"assistant\", \"content\": \"...\"}]}\n```\n\nSee `tinker_cookbook/example_data/conversations.jsonl` for the expected format.\n\n### From HuggingFace datasets\n\n```python\nfrom tinker_cookbook.supervised.data import SupervisedDatasetFromHFDataset\n\ndataset = SupervisedDatasetFromHFDataset(\n    hf_dataset=hf_dataset,\n    batch_size=128,\n    map_fn=lambda example: conversation_to_datum(\n        example[\"messages\"], renderer, max_length, train_on_what\n    ),\n)\n```\n\n### Low-level datum construction\n\n```python\nfrom tinker_cookbook.supervised.data import conversation_to_datum\n\n# Full pipeline: messages → datum\ndatum = conversation_to_datum(messages, renderer, max_length, train_on_what)\n\n# Or step by step:\nmodel_input, weights = renderer.build_supervised_example(messages)\ndatum = datum_from_model_input_weights(model_input, weights, max_length)\n```\n\n## RL datasets\n\nRL datasets return batches of `EnvGroupBuilder` objects. 
See the `/environments` skill for details.\n\n```python\n@chz.chz\nclass MyRLDatasetBuilder(RLDatasetBuilder):\n    batch_size: int = 128\n    group_size: int = 4\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:\n        # Return (train_dataset, optional_test_dataset)\n        ...\n```\n\n## DPO datasets\n\nDPO uses comparison pairs (chosen vs rejected):\n\n```python\nfrom tinker_cookbook.preference.dpo_datasets import DPODatasetBuilderFromComparisons\n\ndataset = DPODatasetBuilderFromComparisons(\n    common_config=common_config,\n    comparison_builder=HHHComparisonBuilder(),\n)\n```\n\nSee `tinker_cookbook/preference/dpo_datasets.py` and `tinker_cookbook/recipes/preference/datasets.py`.\n\n## Common pitfalls\n- Always use `ChatDatasetBuilderCommonConfig` for consistent tokenizer/renderer setup\n- `batch_size` is in tokens, not examples — larger sequences mean fewer examples per batch\n- Custom JSONL must match the format in `example_data/conversations.jsonl`\n- Use `test_size` to hold out evaluation data from the same distribution\n- Dataset builders must be serializable (`@chz.chz`) for config persistence and sweeps\n"
  },
  {
    "path": ".claude/skills/distillation/SKILL.md",
    "content": "---\nname: distillation\ndescription: Set up and run knowledge distillation (on-policy, off-policy, or multi-teacher) from a teacher model to a student model using the Tinker API. Use when the user wants to distill knowledge, compress models, or train a student from a teacher.\nargument-hint: \"[student-model] [teacher-model]\"\n---\n\n# Knowledge Distillation\n\nHelp the user set up and run distillation from teacher to student models using the Tinker API.\n\n## Step 1: Understand the request\n\nAsk the user (if not already specified):\n- **Student model**: Which model to train (e.g., `Qwen/Qwen3-8B-Base`)\n- **Teacher model**: Which model to distill from (e.g., `Qwen/Qwen3-8B`, or a checkpoint path)\n- **Distillation type**:\n  - **On-policy**: Student generates, teacher scores via KL — best for reasoning/chat\n  - **Off-policy reasoning**: SFT on teacher-generated reasoning traces (e.g., OpenThoughts3)\n  - **Multi-teacher**: Combine multiple teachers on different datasets\n\n## Step 2: Reference existing recipes\n\nRead these files for patterns:\n- `tinker_cookbook/recipes/distillation/on_policy_distillation.py` — On-policy distillation CLI\n- `tinker_cookbook/recipes/distillation/off_policy_reasoning.py` — SFT on OpenThoughts3 traces\n- `tinker_cookbook/recipes/distillation/on_policy_multi_teacher.py` — Multi-teacher setup\n- `tinker_cookbook/distillation/train_on_policy.py` — Core on-policy training loop\n- `tinker_cookbook/distillation/datasets.py` — TeacherConfig, PromptOnlyDatasetBuilder, DistillationDatasetConfig\n\n## Step 3: Choose distillation approach\n\n### On-Policy Distillation (Recommended)\nStudent generates samples, teacher provides KL penalty supervision. No correctness rewards needed.\n\nKey config:\n- `TeacherConfig(base_model=\"Qwen/Qwen3-8B\", load_checkpoint_path=None)`\n- `PromptOnlyDatasetBuilder(dataset_name=\"deepmath\"|\"tulu3\", ...)`\n- `DistillationDatasetConfig(dataset_builder=..., teacher_config=..., groups_per_batch=...)`\n- `kl_penalty_coef`: Weight of KL penalty (default 1.0)\n- `kl_discount_factor`: Discount for future KL (0.0 = no discount)\n\n### Off-Policy Reasoning (SFT on Traces)\nStandard SFT on pre-generated reasoning traces (e.g., OpenThoughts3). Simpler but less effective than on-policy.\n\nSee `recipes/distillation/off_policy_reasoning.py` — uses the standard SL pipeline from `supervised/train.py`.\n\n### Multi-Teacher Distillation\nCombine multiple teacher models on different datasets. 
Each dataset can have its own teacher.\n\nSee `recipes/distillation/on_policy_multi_teacher.py` — passes multiple `DistillationDatasetConfig` objects.\n\n## Step 4: Write the training script\n\nFollow the on-policy distillation pattern:\n\n```python\nimport asyncio\nimport chz\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.distillation import train_on_policy\nfrom tinker_cookbook.distillation.datasets import (\n    DistillationDatasetConfig,\n    PromptOnlyDatasetBuilder,\n    TeacherConfig,\n)\n\n@chz.chz\nclass CLIConfig:\n    model_name: str = \"Qwen/Qwen3-8B-Base\"       # Student\n    teacher_model: str = \"Qwen/Qwen3-8B\"          # Teacher\n    dataset: str = \"deepmath\"                       # deepmath or tulu3\n    group_size: int = 4\n    groups_per_batch: int = 1024\n    learning_rate: float = 1e-4\n    max_tokens: int = 4096\n    kl_penalty_coef: float = 1.0\n    kl_discount_factor: float = 0.0\n    lora_rank: int = 128\n    loss_fn: str = \"importance_sampling\"\n\nasync def cli_main(cli_config: CLIConfig):\n    renderer_name = await checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async(\n        model_name=cli_config.model_name, ...)\n\n    dataset_builder = PromptOnlyDatasetBuilder(\n        dataset_name=cli_config.dataset,\n        groups_per_batch=cli_config.groups_per_batch,\n        group_size=cli_config.group_size,\n        model_name_for_tokenizer=cli_config.model_name,\n        renderer_name=renderer_name,\n    )\n    teacher_config = TeacherConfig(base_model=cli_config.teacher_model)\n    dataset_config = DistillationDatasetConfig(\n        dataset_builder=dataset_builder,\n        teacher_config=teacher_config,\n        groups_per_batch=cli_config.groups_per_batch,\n    )\n    config = train_on_policy.Config(\n        dataset_configs=[dataset_config],\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        learning_rate=cli_config.learning_rate,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        kl_discount_factor=cli_config.kl_discount_factor,\n        loss_fn=cli_config.loss_fn,\n        log_path=\"/tmp/tinker-examples/distillation/my_run\",\n    )\n    await train_on_policy.main(config)\n```\n\n## Step 5: Run\n\n```bash\n# On-policy distillation (reasoning)\npython -m tinker_cookbook.recipes.distillation.on_policy_distillation \\\n    model_name=Qwen/Qwen3-8B-Base dataset=deepmath learning_rate=1e-4\n\n# Off-policy reasoning (SFT on traces)\npython -m tinker_cookbook.recipes.distillation.off_policy_reasoning \\\n    model_name=Qwen/Qwen3-8B-Base learning_rate=2e-4\n\n# Multi-teacher\npython -m tinker_cookbook.recipes.distillation.on_policy_multi_teacher \\\n    model_name=Qwen/Qwen3-8B-Base learning_rate=1e-4\n```\n\n## Step 6: Add tests\n\nIf you created a new distillation recipe, add a smoke test:\n\n```python\n# tests/recipes/test_recipe_<name>.py\nimport pytest\nfrom tests.helpers import run_recipe\n\n@pytest.mark.integration\ndef test_<recipe_name>():\n    run_recipe(\n        \"tinker_cookbook.recipes.<recipe_name>.train\",\n        [\"behavior_if_log_dir_exists=delete\", \"groups_per_batch=16\"],\n    )\n```\n\n`run_recipe()` automatically passes `max_steps=2` so the recipe runs 2 training steps and exits. 
See `tests/recipes/test_recipe_on_policy_distillation.py` and `tests/recipes/test_recipe_on_policy_multi_teacher.py` for existing examples.\n\n## Step 7: Export weights (optional)\n\nAfter distillation, export the student model using the `tinker_cookbook.weights` API:\n\n```python\nfrom tinker_cookbook import weights\n\nadapter_dir = weights.download(tinker_path=\"tinker://run-id/sampler_weights/final\", output_dir=\"./adapter\")\nweights.build_hf_model(base_model=\"Qwen/Qwen3-8B-Base\", adapter_path=adapter_dir, output_path=\"./model\")\nweights.publish_to_hf_hub(model_path=\"./model\", repo_id=\"user/my-distilled-model\")\n```\n\n## Common pitfalls\n- Teacher model must be compatible with student's tokenizer/renderer\n- On-policy is generally better than off-policy but more compute-intensive\n- `kl_discount_factor=0.0` means no discounting — increase for longer sequences\n- High `kl_penalty_coef` can make training too conservative\n- For multi-teacher, ensure `groups_per_batch` is balanced across datasets\n"
  },
  {
    "path": ".claude/skills/dpo/SKILL.md",
    "content": "---\nname: dpo\ndescription: Set up and run Direct Preference Optimization (DPO) training on preference datasets using the Tinker API. Use when the user wants to train with preference data, chosen/rejected pairs, or DPO.\nargument-hint: \"[model-name] [dataset]\"\n---\n\n# Direct Preference Optimization (DPO)\n\nHelp the user set up and run DPO training using the Tinker API.\n\n## Step 1: Understand the request\n\nAsk the user (if not already specified):\n- **Model**: Which model to train (e.g., `meta-llama/Llama-3.2-1B`, `Qwen/Qwen3-8B`)\n- **Dataset**: Which preference dataset — built-in (HHH, HelpSteer3, UltraFeedback) or custom\n- **Starting checkpoint**: Train from base model or from an SFT checkpoint\n\n## Step 2: Reference existing recipes\n\nRead these files for patterns:\n- `tinker_cookbook/recipes/preference/dpo/train.py` — DPO CLI with built-in datasets\n- `tinker_cookbook/preference/train_dpo.py` — Core DPO training loop\n- `tinker_cookbook/preference/dpo_datasets.py` — DPO dataset builders\n- `tinker_cookbook/recipes/preference/datasets.py` — HHH, HelpSteer3, UltraFeedback builders\n- `docs/preferences/dpo-guide.mdx` — DPO guide\n\n## Step 3: Configure the training run\n\n### Key Parameters\n\n- `dpo_beta`: Controls how much the model deviates from reference. **Start with 0.1** (recommended default).\n  - Lower beta = more deviation from reference (more aggressive optimization)\n  - Higher beta = stays closer to reference (more conservative)\n- `learning_rate`: Typically **1e-5** for DPO (lower than SFT)\n- `lr_schedule`: `\"linear\"` decay is standard\n- `batch_size`: Number of tokens per batch (default: 256)\n- `max_length`: Maximum sequence length (default: 8192)\n- `reference_model_name`: Explicit reference model (defaults to the base model)\n\n### Preference Datasets\n\n**Built-in:**\n- `\"hhh\"` — Anthropic HHH (Helpful, Harmless, Honest) comparisons\n- `\"helpsteer3\"` — NVIDIA HelpSteer3 preference data\n- `\"ultrafeedback\"` — UltraFeedback preference data\n\n**Custom:** Create a `ComparisonBuilder` that yields `(chosen, rejected)` conversation pairs. 
See `recipes/preference/datasets.py` for examples.\n\n### Dataset Construction\n```python\nfrom tinker_cookbook.preference.dpo_datasets import DPODatasetBuilderFromComparisons\nfrom tinker_cookbook.recipes.preference.datasets import HHHComparisonBuilder\n\ncommon_config = ChatDatasetBuilderCommonConfig(\n    model_name_for_tokenizer=model_name,\n    renderer_name=renderer_name,\n    max_length=8192,\n    batch_size=256,\n)\ndataset = DPODatasetBuilderFromComparisons(\n    common_config=common_config,\n    comparison_builder=HHHComparisonBuilder(),\n)\n```\n\n## Step 4: Write the training script\n\nFollow the pattern from `recipes/preference/dpo/train.py`:\n\n```python\nimport chz\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.preference import train_dpo\nfrom tinker_cookbook.preference.dpo_datasets import DPODatasetBuilderFromComparisons\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilderCommonConfig\n\nconfig = train_dpo.Config(\n    log_path=\"/tmp/tinker-examples/dpo/my_run\",\n    model_name=\"meta-llama/Llama-3.2-1B\",\n    renderer_name=renderer_name,\n    dataset_builder=dataset,\n    learning_rate=1e-5,\n    lr_schedule=\"linear\",\n    dpo_beta=0.1,\n    reference_model_name=None,  # Uses base model as reference\n    load_checkpoint_path=None,  # Or path to SFT checkpoint\n)\n\ntrain_dpo.main(config)\n```\n\n## Step 5: Run\n\n```bash\n# Basic DPO with HHH dataset\npython -m tinker_cookbook.recipes.preference.dpo.train dataset=hhh\n\n# With different model and dataset\npython -m tinker_cookbook.recipes.preference.dpo.train \\\n    model_name=meta-llama/Llama-3.1-8B \\\n    dataset=ultrafeedback \\\n    dpo_beta=0.1 \\\n    learning_rate=1e-5\n\n# From an SFT checkpoint\npython -m tinker_cookbook.recipes.preference.dpo.train \\\n    load_checkpoint_path=/tmp/tinker-examples/sft/checkpoint_100\n```\n\n## Step 6: Add tests\n\nIf you created a new DPO recipe, add a smoke test:\n\n```python\n# tests/recipes/test_recipe_<name>.py\nimport pytest\nfrom tests.helpers import run_recipe\n\n@pytest.mark.integration\ndef test_<recipe_name>():\n    run_recipe(\n        \"tinker_cookbook.recipes.<recipe_name>.train\",\n        [\"behavior_if_log_dir_exists=delete\"],\n    )\n```\n\n`run_recipe()` automatically passes `max_steps=2` so the recipe runs 2 training steps and exits. See `tests/recipes/test_recipe_dpo.py` for the existing example.\n\n## Step 7: Export weights (optional)\n\nAfter DPO training, export weights using the `tinker_cookbook.weights` API:\n\n```python\nfrom tinker_cookbook import weights\n\nadapter_dir = weights.download(tinker_path=\"tinker://run-id/sampler_weights/final\", output_dir=\"./adapter\")\nweights.build_hf_model(base_model=\"meta-llama/Llama-3.2-1B\", adapter_path=adapter_dir, output_path=\"./model\")\nweights.publish_to_hf_hub(model_path=\"./model\", repo_id=\"user/my-dpo-model\")\n```\n\n## Common pitfalls\n- **Start with `dpo_beta=0.1`** — this is well-tested. Tune from there.\n- DPO LR should be **lower than SFT** (1e-5 vs 2e-4)\n- DPO works best when starting from an SFT checkpoint, not a raw base model\n- Reference model defaults to the base model — set `reference_model_name` explicitly if you want a different reference\n- Preference data quality matters more than quantity — ensure chosen/rejected pairs have clear quality differences\n"
  },
  {
    "path": ".claude/skills/environments/SKILL.md",
    "content": "---\nname: environments\ndescription: Guide for defining RL environments — the Env protocol, EnvGroupBuilder, RLDataset, and custom environment creation. Use when the user asks about RL environments, reward functions, or how to define custom tasks for RL training.\n---\n\n# RL Environments\n\nRL training requires environments that provide observations and rewards. This skill covers how to define and use them.\n\n## Reference\n\nRead these for details:\n- `tinker_cookbook/rl/types.py` — Env, EnvGroupBuilder, RLDatasetBuilder, Trajectory\n- `docs/rl/rl-envs.mdx` — Custom environments guide\n- `tinker_cookbook/recipes/math_rl/math_env.py` — Math environment example\n- `tinker_cookbook/recipes/harbor_rl/harbor_env.py` — Multi-turn sandbox environment\n- `tinker_cookbook/rl/message_env.py` — Message-based environment interface\n- `CONTRIBUTING.md` — Env lifecycle and design conventions\n\n## Core types\n\n### Env (single-use, no reset)\n\n```python\nfrom tinker_cookbook.rl.types import Env, Observation, Action, StepResult, StopCondition\n\nclass MyEnv(Env):\n    async def initial_observation(self) -> tuple[Observation, StopCondition]:\n        \"\"\"Return the initial prompt and stop condition.\"\"\"\n        model_input = renderer.build_generation_prompt(messages)\n        stop = renderer.get_stop_sequences()\n        return model_input, stop\n\n    async def step(self, action: Action) -> StepResult:\n        \"\"\"Process model output and return next observation + reward.\"\"\"\n        # action is TokensWithLogprobs (tokens + logprobs)\n        return StepResult(\n            observation=next_model_input,\n            stop_condition=stop,\n            reward=reward_value,\n            episode_done=True,\n            metrics={\"accuracy\": 1.0},\n        )\n```\n\n**Important:** Env objects are **single-use** — no reset method. Create fresh envs via EnvGroupBuilder each batch.\n\n### EnvGroupBuilder\n\nCreates a group of envs for the same prompt/task. 
Advantages are centered within each group (GRPO).\n\n```python\nfrom tinker_cookbook.rl.types import EnvGroupBuilder, TrajectoryGroup\n\nclass MyEnvGroupBuilder(EnvGroupBuilder):\n    async def make_envs(self) -> Sequence[Env]:\n        \"\"\"Return group_size envs for the same task.\"\"\"\n        return [MyEnv(problem=self.problem) for _ in range(self.group_size)]\n\n    async def compute_group_rewards(\n        self, trajectory_group: list[Trajectory], env_group: Sequence[Env]\n    ) -> list[tuple[float, Metrics]]:\n        \"\"\"Compute final rewards for each trajectory in the group.\"\"\"\n        return [(env.reward, {\"solved\": env.reward > 0}) for env in env_group]\n\n    def logging_tags(self) -> list[str]:\n        return [\"my_task\"]\n```\n\n### RLDatasetBuilder\n\nBuilds train/test datasets of EnvGroupBuilders:\n\n```python\n@chz.chz\nclass MyDatasetBuilder(RLDatasetBuilder):\n    batch_size: int = 128\n    group_size: int = 4\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:\n        # Return (train_dataset, optional_test_dataset)\n        ...\n```\n\n## Key data types\n\n```python\n@dataclass\nclass Transition:\n    ob: Observation       # ModelInput\n    ac: TokensWithLogprobs  # Action with logprobs\n    reward: float\n    episode_done: bool\n\n@dataclass\nclass Trajectory:\n    transitions: list[Transition]\n    final_ob: Observation\n\n@dataclass\nclass TrajectoryGroup:\n    trajectories_G: list[Trajectory]\n    final_rewards_G: list[float]\n    metrics_G: list[Metrics]\n```\n\n## Patterns\n\n### Single-turn (math, classification)\nModel generates one response, gets a reward. See `recipes/math_rl/math_env.py`.\n\n### Multi-turn (tool use, sandbox)\nModel generates, environment responds, repeat. See `recipes/harbor_rl/harbor_env.py` and `docs/rl/sequence-extension.mdx` for KV-cache support.\n\n### Multiplayer (games)\nGroup of envs represents a game — envs within the group interact. See `recipes/multiplayer_rl/text_arena/env.py`.\n\n### Preference-based (RLHF)\nGroup of envs generates completions, preference model scores pairs. See `tinker_cookbook/rl/preference_envs.py`.\n\n## Pluggable rollout executor\n\nFor scaling rollout collection, `train.main()` accepts an optional `rollout_executor` parameter:\n\n```python\nfrom concurrent.futures import ProcessPoolExecutor\nfrom tinker_cookbook.rl.train import main\n\nawait main(config, rollout_executor=ProcessPoolExecutor(max_workers=4))\n```\n\nEnvGroupBuilders must be **pickleable** for distributed execution. Test with `tinker_cookbook/rl/builder_pickle_test.py`.\n\n## Dimension conventions\n\n- `_P` = problems (different prompts/tasks)\n- `_G` = groups (multiple rollouts per problem)\n- `_T` = tokens (sequence position)\n- `_D` = datums (training data items)\n\nExample: `tokens_P_G_T[p][g][t]` = token `t` of group `g` of problem `p`.\n\n## Common pitfalls\n- Envs are **single-use** — always create fresh ones via EnvGroupBuilder\n- Advantages are centered within each group — `group_size` affects variance reduction\n- EnvGroupBuilders must be pickleable for distributed rollout execution\n- Shared resources (DB connections, sandboxes) should be managed by the builder, not the env\n- For multi-turn envs, use `max_steps_off_policy` for async rollouts when env execution is slow\n"
  },
  {
    "path": ".claude/skills/evals/SKILL.md",
    "content": "---\nname: evals\ndescription: Guide for evaluation — inline evaluators, Inspect AI integration, and custom evaluators for measuring training progress. Use when the user asks about evaluation, metrics, benchmarks, or how to measure model quality during training.\n---\n\n# Evaluation\n\nTraining scripts support inline evaluation at configurable intervals. The cookbook provides several evaluator patterns.\n\n## Reference\n\nRead these for details:\n- `docs/evals.mdx` — Evaluation guide\n- `tinker_cookbook/supervised/train.py` — SL evaluator integration (search for `evaluator_builders`)\n- `tinker_cookbook/rl/train.py` — RL evaluator integration\n- `tinker_cookbook/recipes/chat_sl/train.py` — Example with Inspect AI evaluators\n\n## Evaluator types\n\n### SL evaluators\nSL training supports two evaluator tiers:\n\n```python\nconfig = supervised_train.Config(\n    evaluator_builders=[...],              # Run every eval_every steps\n    infrequent_evaluator_builders=[...],   # Run every infrequent_eval_every steps\n    eval_every=8,\n    infrequent_eval_every=50,\n)\n```\n\n### RL evaluators\nRL training uses `SamplingClientEvaluator`:\n\n```python\nasync def my_evaluator(sampling_client: SamplingClient) -> dict[str, float]:\n    # Generate samples, compute metrics\n    return {\"accuracy\": 0.85, \"avg_length\": 150}\n\nconfig = rl_train.Config(\n    evaluator_builders=[my_evaluator],\n    eval_every=20,\n)\n```\n\n### RL test set evaluator\nEvaluates the policy on a held-out test set of environments:\n\n```python\n# Built into rl/train.py via test_dataset from RLDatasetBuilder\n# RLDatasetBuilder.__call__() returns (train_dataset, test_dataset)\n```\n\n## Inspect AI integration\n\nThe cookbook integrates with [Inspect AI](https://inspect.ai) for standard benchmarks:\n\n```python\nfrom tinker_cookbook.eval.inspect_utils import InspectAPIFromTinkerSampling\n\n# Create an Inspect evaluator that uses Tinker sampling\nevaluator = InspectAPIFromTinkerSampling(\n    task=\"gsm8k\",          # Inspect task name\n    renderer_name=renderer_name,\n    model_name=model_name,\n    include_reasoning=True,  # Include reasoning traces\n)\n```\n\nSee `tinker_cookbook/recipes/chat_sl/train.py` for a working example with GSM8K and IFEval.\n\n## Custom evaluators\n\n### Pattern 1: Sampling-based evaluation\n\n```python\nasync def eval_math(sampling_client: SamplingClient) -> dict[str, float]:\n    correct = 0\n    total = 100\n    for problem in test_problems:\n        response = sampling_client.sample(\n            prompt=problem.prompt,\n            num_samples=1,\n            sampling_params=SamplingParams(max_tokens=256, temperature=0.0),\n        )\n        answer = parse_answer(response.sequences[0].tokens)\n        if answer == problem.expected:\n            correct += 1\n    return {\"math_accuracy\": correct / total}\n```\n\n### Pattern 2: NLL-based evaluation\n\nCompute negative log-likelihood on a held-out dataset without generating text. 
See `tinker_cookbook/supervised/train.py` for the built-in NLL evaluator.\n\n## Metrics logging\n\nUse the `ml_log` logger that the training scripts set up (see the `/logging` skill for details):\n\n```python\nfrom tinker_cookbook.utils import ml_log\n\nml_logger = ml_log.setup_logging(log_path=\"/tmp/my_run\", wandb_project=None, wandb_name=None)\nml_logger.log_metrics({\"train/loss\": 0.5, \"eval/accuracy\": 0.85}, step=100)\n```\n\n## Common pitfalls\n- Evaluators run inline during training — keep them fast to avoid stalling the training loop\n- Use `infrequent_evaluator_builders` for expensive evals (large benchmarks)\n- RL evaluators receive a SamplingClient — create completers from it if needed\n- For Inspect AI, set `include_reasoning=True` to capture thinking traces\n"
  },
  {
    "path": ".claude/skills/grpo/SKILL.md",
    "content": "---\nname: grpo\ndescription: Set up and run reinforcement learning with verifiable rewards (RLVR/GRPO) for math, code, or custom environments using the Tinker API. Use when the user wants to do RL training, GRPO, reward-based optimization, or train with verifiable rewards.\nargument-hint: \"[model-name] [environment]\"\n---\n\n# Group Relative Policy Optimization (GRPO / RL)\n\nHelp the user set up and run RL training with verifiable rewards using the Tinker API.\n\n## Step 1: Understand the request\n\nAsk the user (if not already specified):\n- **Model**: Which model to train (e.g., `meta-llama/Llama-3.1-8B-Instruct`, `Qwen/Qwen3-8B`)\n- **Environment/Task**: What type of reward signal — math (GSM8K, DeepMath, arithmetic), code (DeepCoder), instruction following (IFBench), or custom\n- **Reward type**: Verifiable (programmatic correctness) or learned (preference model)\n\n## Step 2: Reference existing recipes\n\nRead these files for patterns:\n- `tinker_cookbook/recipes/rl_basic.py` — Minimal RL example (GSM8K)\n- `tinker_cookbook/recipes/math_rl/train.py` — Full math RL with multiple environments and loss functions\n- `tinker_cookbook/recipes/code_rl/train.py` — Code generation RL with sandbox execution\n- `tinker_cookbook/recipes/rubric/train.py` — Rubric-graded RL with LLM scoring\n- `tinker_cookbook/rl/train.py` — Core RL training loop\n- `tinker_cookbook/rl/types.py` — Env, EnvGroupBuilder, RLDatasetBuilder\n- `docs/rl/rl-basic.mdx` — Getting started\n- `docs/rl/rl-envs.mdx` — Custom environments\n- `docs/rl/rl-hyperparams.mdx` — Hyperparameter guidance\n\n## Step 3: Configure the training run\n\n### Environment Setup\nRL requires an environment that produces rewards. Key patterns:\n\n**Built-in environments:**\n- `Gsm8kDatasetBuilder` — Grade-school math (from `recipes/math_rl/math_env.py`)\n- `ArithmeticDatasetBuilder` — Simple arithmetic\n- `DeepMathDatasetBuilder`, `PolarisDatasetBuilder` — Advanced math\n- `DeepCoderDatasetBuilder` — Code generation with sandbox\n- `RubricDatasetBuilder` — Rubric-graded tasks\n\n**Custom environments:**\nImplement the `Env` protocol from `tinker_cookbook/rl/types.py`. Key points:\n- `Env` objects are **single-use** (no reset method)\n- Create new envs via `EnvGroupBuilder` each batch\n- Each env returns a `float` reward\n\n### Key Hyperparameters\n\n- `group_size`: Number of rollouts per prompt (typically 4-16). Advantages are centered within each group.\n- `groups_per_batch` (or `batch_size`): Number of problems per batch\n- `max_tokens`: Maximum generation length\n- `learning_rate`: Typically 1e-5 to 4e-5 for RL\n- `kl_penalty_coef`: KL penalty against reference model (0.0 = no penalty)\n- `temperature`: Sampling temperature (default 1.0)\n- `rollout_error_tolerance`: Tolerance for rollout errors (`False` = crash on any error, `True` = retry failed trajectories with default budget, `RetryOnFailure(max_retries=5)` = custom retry budget). 
Error counts logged as `rollout_errors/*` metrics.\n\n### Loss Functions\n- `importance_sampling` — Default, on-policy\n- `ppo` — Proximal Policy Optimization (clipped)\n- `cispo` — Conservative Importance Sampling PPO\n- `dro` — Distributionally Robust Optimization\n- Configure via `loss_fn` and `loss_fn_config` parameters\n\n### Async Training (Off-Policy)\nFor overlapping sampling and training:\n```python\nasync_config=AsyncConfig(\n    max_steps_off_policy=cli_config.max_steps_off_policy,\n    groups_per_batch=cli_config.groups_per_batch,\n)\n```\n\n## Step 4: Write the training script\n\nFollow the pattern from `rl_basic.py` / `math_rl/train.py`:\n\n```python\nimport asyncio\nimport chz\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.rl import train\n\ndef build_config_blueprint() -> chz.Blueprint[train.Config]:\n    model_name = \"meta-llama/Llama-3.1-8B-Instruct\"\n    renderer_name = model_info.get_recommended_renderer_name(model_name)\n\n    # Configure your dataset builder with environment\n    builder = ...  # e.g., Gsm8kDatasetBuilder(...)\n\n    return chz.Blueprint(train.Config).apply({\n        \"model_name\": model_name,\n        \"renderer_name\": renderer_name,\n        \"log_path\": \"/tmp/tinker-examples/my_rl_run\",\n        \"dataset_builder\": builder,\n        \"learning_rate\": 4e-5,\n        \"max_tokens\": 256,\n        \"eval_every\": 20,\n    })\n\ndef main(config: train.Config):\n    cli_utils.check_log_dir(config.log_path, behavior_if_exists=\"ask\")\n    asyncio.run(train.main(config))\n```\n\nFor the full CLI pattern with `@chz.chz` config class, see `recipes/math_rl/train.py`.\n\n## Step 5: Run\n\n```bash\npython -m tinker_cookbook.recipes.<recipe_name>\n```\n\nOverride: `python -m tinker_cookbook.recipes.<recipe_name> env=gsm8k group_size=16 learning_rate=4e-5`\n\n## Step 6: Add tests\n\nIf you created a new recipe, add a smoke test so CI catches regressions:\n\n```python\n# tests/recipes/test_recipe_<name>.py\nimport pytest\nfrom tests.helpers import run_recipe\n\n@pytest.mark.integration\ndef test_<recipe_name>():\n    run_recipe(\n        \"tinker_cookbook.recipes.<recipe_name>.train\",\n        [\"behavior_if_log_dir_exists=delete\", \"groups_per_batch=4\", \"group_size=2\"],\n    )\n```\n\n`run_recipe()` automatically passes `max_steps=2` so the recipe runs 2 training steps and exits. For environment logic (reward grading, env setup), add unit tests as `*_test.py` next to the source:\n- Example: `tinker_cookbook/recipes/math_rl/math_env_test.py`\n\n## Step 7: Export weights (optional)\n\nAfter training, export weights using the `tinker_cookbook.weights` API:\n\n```python\nfrom tinker_cookbook import weights\n\nadapter_dir = weights.download(tinker_path=\"tinker://run-id/sampler_weights/final\", output_dir=\"./adapter\")\nweights.build_hf_model(base_model=\"meta-llama/Llama-3.1-8B-Instruct\", adapter_path=adapter_dir, output_path=\"./model\")\nweights.publish_to_hf_hub(model_path=\"./model\", repo_id=\"user/my-finetuned-model\")\n```\n\n## Common pitfalls\n- `Env` objects are single-use — always create fresh envs via builder\n- Advantages are centered within each group — `group_size` matters for variance reduction\n- `max_tokens` too small truncates reasoning; too large wastes compute\n- Start with small `groups_per_batch` for debugging, scale up for real runs\n- Use `num_substeps > 1` for very large batches to split optimizer steps\n"
  },
  {
    "path": ".claude/skills/hyperparams/SKILL.md",
    "content": "---\nname: hyperparams\ndescription: Guide for hyperparameter selection — learning rate formulas, LoRA rank, batch size, group size, schedules, and model-specific tuning. Use when the user asks about learning rate, batch size, hyperparameter tuning, or how to configure training parameters.\n---\n\n# Hyperparameter Selection\n\nGuide for choosing training hyperparameters across SL, RL, DPO, and distillation.\n\n## Reference\n\n- `docs/supervised-learning/sl-hyperparams.mdx` — SL hyperparameter guide with LR formula\n- `docs/rl/rl-hyperparams.mdx` — RL hyperparameters (batch_size, group_size, num_substeps, async)\n- `tinker_cookbook/hyperparam_utils.py` — LR formulas and model-specific calculations\n\n## Learning rate\n\n### The formula\n\nThe recommended LR for a model `m` with LoRA:\n\n```\nLR(m) = lr_base × M_LoRA × (2000 / H_m) ^ P_m\n```\n\nWhere:\n- `lr_base = 5e-5`\n- `M_LoRA = 10` (1 for full fine-tuning)\n- `H_m` = hidden size of the model\n- `P_m` = model-specific exponent (0.0775 for Qwen, 0.781 for Llama)\n\n### Use the helper function\n\n```python\nfrom tinker_cookbook.hyperparam_utils import get_lr\n\nlr = get_lr(\"meta-llama/Llama-3.1-8B\", is_lora=True)\n# Returns model-specific recommended LR\n```\n\nThis formula gives <0.5% regret vs exhaustive sweeps across diverse SFT experiments.\n\n### Rules of thumb\n\n| Training type | Typical LR range | Notes |\n|---------------|------------------|-------|\n| SL (LoRA) | 1e-4 to 5e-4 | Use `get_lr()` |\n| SL (full FT) | 1e-5 to 5e-5 | LoRA LR / 10 |\n| RL | 1e-5 to 4e-5 | Lower than SL |\n| DPO | ~1e-5 | Much lower than SL |\n| RLHF (RL stage) | ~1e-5 | Same as RL |\n| Distillation | ~1e-4 | Similar to SL |\n\n## LoRA rank\n\n- **Default**: 32 for most tasks\n- **Higher rank** (64–128): More capacity, needed for complex tasks or large models\n- **Lower rank** (8–16): Faster, sufficient for simple adaptations\n- LR is **independent** of LoRA rank (validated empirically)\n\n```python\nfrom tinker_cookbook.hyperparam_utils import get_lora_param_count\n\n# Check parameter count for a given rank\nparams = get_lora_param_count(\"meta-llama/Llama-3.1-8B\", lora_rank=32)\n```\n\n## Batch size\n\n### SL batch size\n- Measured in **tokens**, not examples\n- **Recommended**: Start with 128\n- Smaller batch sizes often give better final performance at cost of longer training\n- Scale LR proportionally: `LR ∝ √batch_size`\n- Aim for at least 100 training steps (best results with 1000+)\n\n### RL batch size and group size\nTwo parameters control RL batch composition:\n\n- **`batch_size`** (or `groups_per_batch`): Number of unique problems/environments per batch\n- **`group_size`**: Number of rollouts per problem (advantages centered within group)\n\n```\ntotal_rollouts = batch_size × group_size\n```\n\nGuidelines:\n- If limited problems: increase `group_size` for more training signal\n- Scale LR with batch_size: `LR ∝ √batch_size`\n- Start small for debugging (`groups_per_batch=4, group_size=2`)\n\n## Learning rate schedule\n\nAvailable schedules:\n- `\"linear\"` — Linear decay to 0 (most common)\n- `\"cosine\"` — Cosine annealing\n- `\"constant\"` — No decay\n\nSet via `lr_schedule` parameter in config.\n\n## `num_substeps` (RL)\n\nControls how many optimizer updates per sampling iteration:\n\n- `num_substeps=1` (default): One update per batch — simplest, usually sufficient\n- `num_substeps>1`: Splits batch into mini-batches, one update each. 
Requires PPO objective.\n- Start with 2–4 if experimenting; decrease LR with higher values\n\n## DPO-specific\n\n- **`dpo_beta=0.1`** — Well-tested default. Controls deviation from reference model.\n- Lower beta = more aggressive optimization\n- Higher beta = stays closer to reference\n\n## Distillation-specific\n\n- **`kl_penalty_coef=1.0`** — Weight of KL penalty from teacher\n- **`kl_discount_factor=0.0`** — No discounting (increase for long sequences)\n\n## Quick-start recommendations\n\n| Scenario | Model | LR | Batch | LoRA Rank |\n|----------|-------|-----|-------|-----------|\n| SFT on chat data | Llama-3.1-8B | `get_lr(model)` | 128 | 32 |\n| Math GRPO | Llama-3.1-8B-Instruct | 4e-5 | 128×16 | 32 |\n| DPO | Llama-3.2-1B | 1e-5 | 256 | 32 |\n| Distillation | Qwen3-8B-Base | 1e-4 | 1024×4 | 128 |\n| Multi-turn RL | Kimi-K2-Thinking | 1e-5 | 8×4 | 32 |\n\n## Common pitfalls\n- LoRA needs ~10x higher LR than full fine-tuning — use `get_lr()` to get it right\n- `get_lr()` currently only supports Llama and Qwen families — other models need manual tuning\n- DPO LR should be much lower than SFT (1e-5 vs 2e-4)\n- RL LR should be lower than SFT — too aggressive updates destabilize the policy\n- Batch size too small = noisy gradients; too large = diminishing returns\n- Monitor KL divergence in RL — training is stable when KL < 0.01\n"
  },
  {
    "path": ".claude/skills/logging/SKILL.md",
    "content": "---\nname: logging\ndescription: Guide for training outputs, metrics logging, logtree reports, tracing/profiling, and debugging training runs. Use when the user asks about training logs, metrics, debugging, tracing, profiling, timing, Gantt charts, or understanding training output files.\n---\n\n# Logging & Debugging\n\nEvery training run writes structured outputs to `log_path`. This skill covers what's produced and how to use it.\n\n## Reference\n\n- `docs/rl/rl-logging.mdx` — Complete file reference for RL training outputs\n- `tinker_cookbook/utils/ml_log.py` — Metrics logging API\n- `tinker_cookbook/utils/logtree.py` — Logtree (structured rollout transcripts)\n- `tinker_cookbook/utils/trace.py` — Tracing/profiling (`@scope`, `trace_iteration`, Gantt charts)\n\n## Output files\n\nEach training run writes to its `log_path` directory:\n\n| File | Format | Contents |\n|------|--------|----------|\n| `metrics.jsonl` | JSONL | Scalar metrics per training iteration |\n| `config.json` | JSON | Full serialized training config (reproducibility) |\n| `checkpoints.jsonl` | JSONL | Checkpoint metadata (paths, loop state for resume) |\n| `code.diff` | text | Git diff at training start |\n| `train_iteration_NNNNNN.html` | HTML | Human-readable logtree report |\n| `train_iteration_NNNNNN_logtree.json` | JSON | Machine-readable rollout transcripts |\n| `train_iteration_NNNNNN_rollout_summaries.jsonl` | JSONL | Per-trajectory rewards and metrics |\n| `eval_<name>_iteration_NNNNNN.*` | mixed | Same formats for eval rollouts |\n| `timing_spans.jsonl` | JSONL | Per-iteration span timing data (from `trace_iteration`) |\n| `trace_events.jsonl` | JSONL | Perfetto/Chrome Trace format events (from `trace_init`) |\n| `gantt_NNNNNN.html` | HTML | Plotly Gantt chart of span timeline (optional) |\n\nIteration numbers are zero-padded to 6 digits.\n\n## Analyzing metrics\n\n```python\nimport pandas as pd\n\ndf = pd.read_json(\"path/to/log_path/metrics.jsonl\", lines=True)\ndf.plot(x=\"progress/batch\", y=\"env/all/reward/total\")\n```\n\n### Common metric keys\n\n**Progress:**\n- `progress/batch` — iteration index\n- `progress/done_frac` — completion fraction\n\n**RL rewards:**\n- `env/all/reward/total` — mean total reward\n- `env/all/<metric>` — env-emitted metrics (e.g., `correct`, `format_parse`)\n\n**Training health:**\n- `entropy` — per-token entropy\n- `kl_sample_train_v1`, `kl_sample_train_v2` — KL divergence (should stay < 0.01)\n- `optim/lr` — current learning rate\n- `ac_tokens_per_turn` — mean generated tokens per turn\n\n**Timing** (from `trace_iteration`):\n- `time/total` — iteration wall-clock duration\n- `time/<name>` — single-call duration (e.g., `time/train_step`)\n- `time/<name>:total`, `time/<name>:count`, `time/<name>:mean`, `time/<name>:max` — aggregates for functions called multiple times (e.g., `time/sample_async:total`)\n\n## Analyzing rollouts\n\n### Rollout summaries (aggregate)\n\n```python\nimport json\n\nwith open(\"train_iteration_000010_rollout_summaries.jsonl\") as f:\n    trajectories = [json.loads(line) for line in f]\n\nfor traj in trajectories:\n    print(f\"reward={traj['total_reward']:.2f}, metrics={traj['trajectory_metrics']}\")\n    # Each trajectory has: total_reward, final_reward, trajectory_metrics,\n    # steps (list of {ob_len, ac_len, reward, episode_done, metrics})\n```\n\n### Logtree JSON (full transcripts)\n\nContains full text of prompts, model responses, grading details. 
Walk the tree recursively looking for nodes with `data.type == \"conversation\"` to extract conversations. See `docs/rl/rl-logging.mdx` for the full schema.\n\n### HTML reports\n\nOpen `train_iteration_NNNNNN.html` in a browser for a human-readable view of rollouts with collapsible sections. `num_groups_to_log` (default: 4) controls how many trajectory groups get detailed logging.\n\n## Logging in your own code\n\n### Scalar metrics\n\n```python\nfrom tinker_cookbook.utils import ml_log\n\n# Set up logging (done once in training scripts)\nml_logger = ml_log.setup_logging(log_path=\"/tmp/my_run\", wandb_project=None, wandb_name=None)\n\n# Log scalar metrics\nml_logger.log_metrics({\"train/loss\": 0.5, \"eval/accuracy\": 0.85}, step=100)\n```\n\n### Logtree (structured transcripts)\n\n```python\nfrom tinker_cookbook.utils import logtree\n\nwith logtree.scope_header(\"my_section\"):\n    # Nested logging of rollouts, grading, etc.\n    ...\n```\n\n## Weights & Biases integration\n\nPass `wandb_project` and `wandb_name` in your config to enable W&B logging:\n\n```python\nconfig = train.Config(\n    wandb_project=\"my-project\",\n    wandb_name=\"my-experiment\",\n    ...\n)\n```\n\n## Tracing & profiling\n\nThe `tinker_cookbook/utils/trace` module provides per-iteration profiling across all training modules (RL, SL, DPO, distillation).\n\n### Core API\n\n```python\nfrom tinker_cookbook.utils import trace\n\n# Initialize Perfetto trace collector (optional — writes trace_events.jsonl)\ntrace.trace_init()\n\n# In training loop — collect per-iteration timing\nfor i_batch in range(n_batches):\n    with trace.trace_iteration(step=i_batch) as window:\n        # All @scope-decorated calls are automatically recorded\n        await gather_rollouts(...)\n        await train_step(...)\n\n    # Get timing metrics for this iteration\n    metrics.update(window.get_timing_metrics())\n\n    # Persist span data for post-hoc analysis\n    window.write_spans_jsonl(log_path / \"timing_spans.jsonl\", step=i_batch)\n\n    # Optional: Gantt chart visualization (requires plotly)\n    trace.save_gantt_chart_html(window, i_batch, log_path / f\"gantt_{i_batch}.html\")\n```\n\n### Instrumenting your code\n\n```python\nfrom tinker_cookbook.utils import trace\n\n# Decorator — automatically traces function calls\n@trace.scope\nasync def my_training_step(tc, batch):\n    result = await tc.forward_backward_async(data=batch, loss_fn=\"cross_entropy\")\n    return result\n\n# Inline span — for timing a code block without a dedicated function\nasync with trace.scope_span(\"data_prep\"):\n    batch = prepare_next_batch(...)\n\n# Sync variant\nwith trace.scope_span_sync(\"data_prep\"):\n    batch = prepare_next_batch(...)\n```\n\n`@scope` and `scope_span` are no-ops when called outside `trace_iteration` — safe to leave in production.\n\n### Viewing Perfetto traces\n\n```bash\n# Convert JSONL to JSON for visualization\nuv run python -m tinker_cookbook.utils.trace trace_events.jsonl trace.json\n# Open trace.json in chrome://tracing or https://ui.perfetto.dev/\n```\n\n## Debugging tips\n\n1. **Training not improving**: Check `metrics.jsonl` — is loss decreasing? Are rewards increasing?\n2. **KL divergence spiking**: KL > 0.01 indicates instability. Lower the learning rate.\n3. **Reward stuck at 0**: Check rollout summaries — are responses being parsed correctly?\n4. **OOM / timeout**: Reduce `batch_size`, `group_size`, or `max_tokens`\n5. **Shrink workloads for debugging**: Set small `batch_size`, `group_size`, and `max_steps`\n6. 
**Compare runs**: Load multiple `metrics.jsonl` files into a DataFrame and overlay plots\n"
  },
  {
    "path": ".claude/skills/manage-skills/SKILL.md",
    "content": "---\nname: manage-skills\ndescription: Create, update, or organize Claude Code skills in this repo. Use when adding a new skill, reviewing existing skills for consistency, or maintaining the skill taxonomy.\ndisable-model-invocation: true\nargument-hint: \"[create|update|audit] [skill-name]\"\n---\n\n# Manage Claude Code Skills\n\nThis meta-skill governs how skills are created and maintained in the tinker-cookbook repo.\n\n## Skill taxonomy\n\nAll skills in `.claude/skills/` are organized into 5 layers:\n\n### Layer 0: Fundamentals (`setup`, `models`, `hyperparams`, `logging`)\n**Scope:** Getting started, model selection, hyperparameter guidance, training output analysis. Cross-cutting concerns needed before touching any code.\n**Auto-invocation:** Yes — triggers when users ask about setup, models, hyperparameters, or debugging.\n**Key principle:** These inform all other layers. Reference `docs/`, `README.md`, `tinker_cookbook/hyperparam_utils.py`.\n\n### Layer 1: Tinker SDK (`tinker-sdk`, `tinker-types`, `tinker-cli`)\n**Scope:** Raw Tinker Python SDK APIs — ServiceClient, TrainingClient, SamplingClient, RestClient, types, errors, and CLI commands.\n**Auto-invocation:** Yes — triggers when users ask about Tinker API basics or CLI usage.\n**Key principle:** Reference `docs/api-reference/` for authoritative API docs.\n\n### Layer 2: Cookbook Primitives (`renderers`, `environments`, `weights`, `completers`, `checkpoints`, `evals`, `datasets`)\n**Scope:** Building blocks in `tinker_cookbook/` — renderers, RL environments, weight lifecycle, completers, checkpointing, evaluators, dataset construction.\n**Auto-invocation:** Yes — triggers when users ask about specific primitives.\n**Key principle:** Reference source code in `tinker_cookbook/` and docs in `docs/`.\n\n### Layer 3: Algorithm / Task Recipes (`sft`, `grpo`, `distillation`, `dpo`, `rlhf`, `multiturn-rl`)\n**Scope:** End-to-end training workflows built on Layer 1 + Layer 2.\n**Auto-invocation:** Yes — triggers when users want to set up a specific training method.\n**Key principle:** Reference recipes in `tinker_cookbook/recipes/` and defer primitive details to Layer 2 skills.\n\n### Layer 4: Repo Development (`new-recipe`, `ci`, `contributing`, `manage-skills`)\n**Scope:** Development workflow — scaffolding, testing, CI, code style, skill maintenance.\n**Auto-invocation:** `contributing` and `ci` auto-invoke; `new-recipe` and `manage-skills` are manual-only.\n**Key principle:** Reference `CONTRIBUTING.md`, `tests/`, `.github/workflows/`.\n\n## Creating a new skill\n\n### Step 1: Determine the layer\nWhich layer does this skill belong to? Skills should have a clear, non-overlapping scope. If it spans layers, split it.\n\n### Step 2: Check for overlap\nRead existing skills in `.claude/skills/` to ensure the new skill doesn't duplicate content. 
If there's overlap, update the existing skill instead.\n\n### Step 3: Create the skill file\n\nCreate `.claude/skills/<skill-name>/SKILL.md` with this structure:\n\n```yaml\n---\nname: <skill-name>\ndescription: <Clear description of what the skill does and when to use it>\nargument-hint: \"[optional args]\"  # Only if the skill takes arguments\ndisable-model-invocation: true    # Only for manual-trigger skills (Layer 4 actions)\n---\n\n# <Skill Title>\n\n<Brief description of what this skill helps with>\n\n## Step 1: Understand the request\n<What to ask the user if not specified>\n\n## Step 2: Reference existing code\n<Which files to read for patterns — be specific with file paths>\n\n## Step 3: Key concepts\n<Core APIs, parameters, patterns>\n\n## Step 4: Implementation\n<Code examples following repo conventions>\n\n## Step N: Add tests\n<Testing guidance — smoke tests and unit tests>\n```\n\n### Step 4: Follow these conventions\n\n**Naming:**\n- Lowercase, hyphenated: `tinker-sdk`, `new-recipe`, `manage-skills`\n- Layer 0: named after the fundamental concept\n- Layer 1: named after the SDK concept\n- Layer 2: named after the primitive\n- Layer 3: named after the algorithm/method\n- Layer 4: named after the dev action\n\n**Content rules:**\n- Always reference **actual file paths** in the repo — never describe APIs from memory\n- Include code examples that follow repo conventions (`@chz.chz`, explicit typing, etc.)\n- For Layer 3 skills: defer primitive details to Layer 2 skills (e.g., say \"see `/renderers` skill\" instead of re-explaining renderers)\n- Include a testing section pointing to `tests/recipes/` for smoke tests and `*_test.py` for unit tests\n- Keep skills under 200 lines — move detailed reference material to separate files in the skill directory\n\n**Frontmatter rules:**\n- `description` is required and must clearly state **when** to trigger the skill\n- Use `disable-model-invocation: true` only for action-oriented Layer 4 skills\n- Use `argument-hint` if the skill takes positional arguments\n\n## Auditing existing skills\n\nWhen auditing, check each skill for:\n\n1. **Accuracy:** Do file paths and API references match the current codebase? Run `ls` or `grep` to verify.\n2. **Freshness:** Has the referenced code changed since the skill was written? Check git log for the referenced files.\n3. **Taxonomy compliance:** Is the skill in the correct layer? Does it overlap with other skills?\n4. **Convention compliance:** Does it follow the structure above? Does it include testing guidance?\n5. 
**Cross-references:** Do Layer 3 skills reference Layer 2 skills where appropriate?\n\n## Current skill inventory\n\n```\n.claude/skills/\n├── Layer 0: Fundamentals\n│   ├── setup/               # Installation, API key, first run\n│   ├── models/              # Model lineup, selection, families\n│   ├── hyperparams/         # LR formulas, batch size, LoRA rank\n│   └── logging/             # Training outputs, metrics, debugging\n├── Layer 1: SDK\n│   ├── tinker-sdk/          # ServiceClient, TrainingClient, SamplingClient, RestClient APIs\n│   ├── tinker-types/        # Datum, ModelInput, TensorData, response types, error types\n│   └── tinker-cli/          # tinker CLI: run/checkpoint management, download, publish\n├── Layer 2: Primitives\n│   ├── renderers/           # Renderer setup, TrainOnWhat, vision\n│   ├── environments/        # Env, EnvGroupBuilder, custom RL envs\n│   ├── weights/             # download, build_hf_model, publish\n│   ├── completers/          # TokenCompleter, MessageCompleter\n│   ├── checkpoints/         # save/load, CheckpointRecord, resume\n│   ├── evals/               # Evaluators, Inspect AI\n│   └── datasets/            # SupervisedDatasetBuilder, RLDatasetBuilder\n├── Layer 3: Recipes\n│   ├── sft/                 # Supervised fine-tuning\n│   ├── grpo/                # RL with verifiable rewards\n│   ├── distillation/        # Knowledge distillation\n│   ├── dpo/                 # Direct Preference Optimization\n│   ├── rlhf/                # RLHF pipeline\n│   └── multiturn-rl/        # Multi-turn RL\n└── Layer 4: Development\n    ├── new-recipe/          # Scaffold new recipe\n    ├── ci/                  # Testing and CI\n    ├── contributing/        # Dev setup and code style\n    └── manage-skills/       # This skill\n```\n\n## Maintenance schedule\n\nWhen the codebase changes significantly (new modules, API changes, renamed files):\n1. Run `/manage-skills audit` to check all skills\n2. Update affected skills\n3. Commit changes with a descriptive message\n"
  },
  {
    "path": ".claude/skills/models/SKILL.md",
    "content": "---\nname: models\ndescription: Guide for choosing models in Tinker — available model families, model types (base, instruction, reasoning, hybrid, vision), architecture (dense vs MoE), and how to match renderers to models. Use when the user asks which model to use, what models are available, or how to pick a model for their task.\n---\n\n# Model Selection\n\nHelp the user choose the right model for their task.\n\n## Reference\n\n- `docs/model-lineup.mdx` — Full model listing with types, sizes, and architecture\n- `tinker_cookbook/model_info.py` — Model metadata and renderer mapping\n\n## Available models\n\n### Qwen family\n| Model | Type | Arch | Size |\n|-------|------|------|------|\n| `Qwen/Qwen3.5-397B-A17B` | Hybrid + Vision | MoE | Large |\n| `Qwen/Qwen3.5-35B-A3B` | Hybrid + Vision | MoE | Medium |\n| `Qwen/Qwen3.5-27B` | Hybrid + Vision | Dense | Medium |\n| `Qwen/Qwen3.5-4B` | Hybrid + Vision | Dense | Compact |\n| `Qwen/Qwen3-235B-A22B-Instruct-2507` | Instruction | MoE | Large |\n| `Qwen/Qwen3-30B-A3B-Instruct-2507` | Instruction | MoE | Medium |\n| `Qwen/Qwen3-30B-A3B` | Hybrid | MoE | Medium |\n| `Qwen/Qwen3-30B-A3B-Base` | Base | MoE | Medium |\n| `Qwen/Qwen3-32B` | Hybrid | Dense | Medium |\n| `Qwen/Qwen3-8B` | Hybrid | Dense | Small |\n| `Qwen/Qwen3-8B-Base` | Base | Dense | Small |\n| `Qwen/Qwen3-4B-Instruct-2507` | Instruction | Dense | Compact |\n| `Qwen/Qwen3-VL-235B-A22B-Instruct` | Vision | MoE | Large |\n| `Qwen/Qwen3-VL-30B-A3B-Instruct` | Vision | MoE | Medium |\n\n### Llama family\n| Model | Type | Arch | Size |\n|-------|------|------|------|\n| `meta-llama/Llama-3.3-70B-Instruct` | Instruction | Dense | Large |\n| `meta-llama/Llama-3.1-70B` | Base | Dense | Large |\n| `meta-llama/Llama-3.1-8B` | Base | Dense | Small |\n| `meta-llama/Llama-3.1-8B-Instruct` | Instruction | Dense | Small |\n| `meta-llama/Llama-3.2-3B` | Base | Dense | Compact |\n| `meta-llama/Llama-3.2-1B` | Base | Dense | Compact |\n\n### Nemotron family\n| Model | Type | Arch | Size |\n|-------|------|------|------|\n| `nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16` | Hybrid | MoE | Large |\n| `nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16` | Hybrid | MoE | Medium |\n\n### Other families\n| Model | Type | Arch | Size |\n|-------|------|------|------|\n| `openai/gpt-oss-120b` | Reasoning | MoE | Medium |\n| `openai/gpt-oss-20b` | Reasoning | MoE | Small |\n| `deepseek-ai/DeepSeek-V3.1` | Hybrid | MoE | Large |\n| `deepseek-ai/DeepSeek-V3.1-Base` | Base | MoE | Large |\n| `moonshotai/Kimi-K2-Thinking` | Reasoning | MoE | Large |\n| `moonshotai/Kimi-K2.5` | Reasoning + Vision | MoE | Large |\n\n## How to choose\n\n### By task type\n\n- **Instruction tuning / chat SFT**: Start with an Instruction model (e.g., `Llama-3.1-8B-Instruct`, `Qwen3-30B-A3B-Instruct-2507`)\n- **RL with verifiable rewards (GRPO)**: Use Instruction or Hybrid models — they already follow instructions\n- **Reasoning / chain-of-thought**: Use Reasoning or Hybrid models (`Kimi-K2-Thinking`, `Qwen3-8B`)\n- **Full post-training pipeline**: Start with a Base model (e.g., `Qwen3-8B-Base`, `Llama-3.1-8B`)\n- **Vision tasks**: Use Vision or Hybrid+Vision models (`Qwen3.5-35B-A3B`, `Qwen3-VL-*`)\n- **Distillation (student)**: Use a Base model as student\n- **Quick prototyping**: Use compact models (`Llama-3.2-1B`, `Qwen3.5-4B`)\n\n### By cost\n\n**Prefer MoE models** — they're much more cost-effective than dense models because cost scales with active parameters, not total parameters. 
For example, `Qwen3-30B-A3B` (MoE, 3B active) is cheaper than `Qwen3-32B` (Dense, 32B active) despite similar quality.\n\n### Model types explained\n\n- **Base**: Pre-trained on raw text. For research or full post-training pipelines.\n- **Instruction**: Fine-tuned for instruction following. Fast inference, no chain-of-thought.\n- **Reasoning**: Always uses chain-of-thought before visible output.\n- **Hybrid**: Can operate in both thinking and non-thinking modes.\n- **Vision**: Processes images alongside text. See `/renderers` skill for vision input handling.\n\n### Size categories\n- **Compact**: 1B–4B parameters\n- **Small**: 8B parameters\n- **Medium**: 27B–32B parameters\n- **Large**: 70B+ parameters\n\n## Renderer matching\n\nEvery model needs a matching renderer. **Always use the automatic lookup**:\n\n```python\nfrom tinker_cookbook import model_info\n\nrenderer_name = model_info.get_recommended_renderer_name(model_name)\n```\n\nNever hardcode renderer names — the mapping is maintained in `model_info.py`.\n\n## Learning rate by model\n\nUse `hyperparam_utils.get_lr(model_name)` for model-specific LR recommendations. See the `/hyperparams` skill for details.\n\n## Common pitfalls\n- MoE models are cheaper than dense — prefer them unless you have a specific reason\n- Base models need full post-training (SFT + alignment) to be useful for chat\n- Instruction models are best for tasks where you want to start from a capable baseline\n- Vision models require `ImageChunk` in messages — see `/renderers` skill\n- Llama models require `HF_TOKEN` for tokenizer download (gated on HuggingFace)\n"
  },
  {
    "path": ".claude/skills/multiturn-rl/SKILL.md",
    "content": "---\nname: multiturn-rl\ndescription: Set up and run multi-turn RL training for interactive environments (terminal tasks, tool use, search/RAG, games) using the Tinker API. Use when the user wants multi-turn RL, agentic training, tool-use RL, or interactive environment training.\nargument-hint: \"[model-name] [environment-type]\"\n---\n\n# Multi-Turn RL Training\n\nHelp the user set up RL training for multi-turn interactive environments using the Tinker API.\n\n## Step 1: Understand the request\n\nAsk the user (if not already specified):\n- **Model**: Which model to train (e.g., `moonshotai/Kimi-K2-Thinking`, `Qwen/Qwen3-8B`)\n- **Environment type**:\n  - **Terminal/sandbox tasks**: Model executes shell commands (Harbor)\n  - **Search/RAG**: Model uses retrieval tools (Search-R1)\n  - **Multiplayer games**: Two models compete (TicTacToe, Twenty Questions, Guess Number)\n  - **Custom multi-turn**: User-defined interactive environment\n- **Turn structure**: Max turns, tool outputs, observation handling\n\n## Step 2: Reference existing recipes\n\nRead these files for patterns:\n- `tinker_cookbook/recipes/harbor_rl/train.py` — Terminal task RL with sandbox execution\n- `tinker_cookbook/recipes/harbor_rl/harbor_env.py` — HarborDatasetBuilder, sandbox factory\n- `tinker_cookbook/recipes/search_tool/train.py` — Search-R1 with Chroma vector DB\n- `tinker_cookbook/recipes/multiplayer_rl/text_arena/train.py` — Two-player games\n- `tinker_cookbook/recipes/multiplayer_rl/twenty_questions/train.py` — Twenty Questions\n- `tinker_cookbook/recipes/multiplayer_rl/guess_number/train.py` — Guess the Number\n- `tinker_cookbook/rl/message_env.py` — Message-based environment interface\n- `docs/rl/sequence-extension.mdx` — Multi-turn RL and KV-cache\n- `docs/rl/rl-envs.mdx` — Custom environments\n\n## Step 3: Configure the environment\n\n### Harbor (Terminal Tasks)\nInteractive sandbox where model runs shell commands and gets outputs:\n\n```python\nfrom tinker_cookbook.recipes.harbor_rl.harbor_env import HarborDatasetBuilder, HarborTask\n\ndataset_builder = HarborDatasetBuilder(\n    tasks=tasks,                    # List of HarborTask objects\n    batch_size=8,                   # groups_per_batch\n    group_size=4,                   # rollouts per task\n    model_name=model_name,\n    renderer_name=renderer_name,\n    max_turns=10,                   # max interaction turns\n    sandbox_timeout=3600,           # sandbox lifetime (seconds)\n    command_timeout=120,            # per-command timeout\n    grader_timeout=60,              # grading timeout\n)\n```\n\n### Search/RAG (Search-R1)\nModel queries a vector database during generation:\n\nSee `recipes/search_tool/train.py` for Chroma integration and streaming minibatch config.\n\n### Multiplayer Games\nTwo models play against each other:\n\nSee `recipes/multiplayer_rl/text_arena/train.py` for the competitive RL pattern.\n\n### Key Multi-Turn Parameters\n\n- `max_turns`: Maximum number of interaction turns\n- `max_tokens`: Max tokens per generation step\n- `kl_penalty_coef`: KL penalty (often 0.0 for multi-turn to allow exploration)\n- `max_steps_off_policy`: Enable async rollouts for expensive environments\n\n### Async Rollouts\nMulti-turn envs are slow due to tool execution. 
Use async config:\n```python\nconfig = Config(\n    ...\n    async_config=AsyncConfig(\n        max_steps_off_policy=cli_config.max_steps_off_policy,\n        groups_per_batch=cli_config.groups_per_batch,\n    ) if cli_config.max_steps_off_policy is not None else None,\n)\n```\n\n## Step 4: Write the training script\n\nFollow the Harbor pattern:\n\n```python\nimport asyncio\nimport chz\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.rl.train import AsyncConfig, Config, main\n\n@chz.chz\nclass CLIConfig:\n    model_name: str = \"moonshotai/Kimi-K2-Thinking\"\n    lora_rank: int = 32\n    max_tokens: int = 8192\n    max_turns: int = 10\n    group_size: int = 4\n    groups_per_batch: int = 8\n    learning_rate: float = 1e-5\n    kl_penalty_coef: float = 0.0\n    max_steps_off_policy: int | None = None\n\nasync def cli_main(cli_config: CLIConfig):\n    renderer_name = model_info.get_recommended_renderer_name(cli_config.model_name)\n\n    dataset_builder = ...  # Your multi-turn dataset builder\n\n    config = Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_builder=dataset_builder,\n        model_name=cli_config.model_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        log_path=\"/tmp/tinker-examples/multiturn/my_run\",\n        async_config=AsyncConfig(\n            max_steps_off_policy=cli_config.max_steps_off_policy,\n            groups_per_batch=cli_config.groups_per_batch,\n        ) if cli_config.max_steps_off_policy is not None else None,\n    )\n\n    await main(config)\n```\n\n## Step 5: Run\n\n```bash\n# Harbor terminal RL\npython -m tinker_cookbook.recipes.harbor_rl.train\n\n# Search-R1\npython -m tinker_cookbook.recipes.search_tool.train\n\n# Multiplayer games\npython -m tinker_cookbook.recipes.multiplayer_rl.text_arena.train\n```\n\n## Step 6: Add tests\n\nIf you created a new multi-turn recipe, add a smoke test:\n\n```python\n# tests/recipes/test_recipe_<name>.py\nimport pytest\nfrom tests.helpers import run_recipe\n\n@pytest.mark.integration\ndef test_<recipe_name>():\n    run_recipe(\n        \"tinker_cookbook.recipes.<recipe_name>.train\",\n        [\"behavior_if_log_dir_exists=delete\", \"groups_per_batch=4\", \"group_size=2\"],\n    )\n```\n\n`run_recipe()` automatically passes `max_steps=2` so the recipe runs 2 training steps and exits. See `tests/recipes/test_recipe_text_arena.py` and `tests/recipes/test_recipe_twenty_questions.py` for existing multi-turn examples. For environment-specific logic (sandbox setup, tool parsing), add unit tests as `*_test.py` next to the source code.\n\n## Common pitfalls\n- Multi-turn envs are expensive — start with small `groups_per_batch` (4-8)\n- Use `max_steps_off_policy` for async rollouts when env execution is slow\n- `Env` objects are single-use — the builder creates fresh envs each batch\n- Sandbox timeouts need to be generous enough for complex tasks\n- KV-cache (sequence extension) is key for multi-turn efficiency — see `docs/rl/sequence-extension.mdx`\n- `kl_penalty_coef=0.0` is common for multi-turn since you want the model to explore tool use\n"
  },
  {
    "path": ".claude/skills/new-recipe/SKILL.md",
    "content": "---\nname: new-recipe\ndescription: Scaffold a new training recipe for the Tinker cookbook following repo conventions. Use when the user wants to create a new recipe, training script, or experiment.\ndisable-model-invocation: true\nargument-hint: \"[recipe-name]\"\n---\n\n# Create a New Training Recipe\n\nScaffold a new training recipe in `tinker_cookbook/recipes/` following repo conventions.\n\n## Step 1: Understand the request\n\nAsk the user:\n- **Recipe name**: What to call it (becomes the directory/file name under `recipes/`)\n- **Training type**: SL, RL, DPO, distillation, or hybrid\n- **Key details**: Model, dataset, environment, reward signal, etc.\n\n## Step 2: Read existing recipes for patterns\n\nBefore writing any code, read the most relevant existing recipe:\n- **SL-based**: Read `tinker_cookbook/recipes/sl_basic.py` and `tinker_cookbook/recipes/chat_sl/train.py`\n- **RL-based**: Read `tinker_cookbook/recipes/rl_basic.py` and `tinker_cookbook/recipes/math_rl/train.py`\n- **DPO-based**: Read `tinker_cookbook/recipes/preference/dpo/train.py`\n- **Distillation-based**: Read `tinker_cookbook/recipes/distillation/on_policy_distillation.py`\n- **Multi-turn RL**: Read `tinker_cookbook/recipes/harbor_rl/train.py`\n\nAlso read `CLAUDE.md` for conventions.\n\n## Step 3: Follow repo conventions\n\nEvery recipe MUST follow these patterns:\n\n### File structure\n```\ntinker_cookbook/recipes/<recipe_name>/\n├── __init__.py        # Empty or minimal\n├── train.py           # Main entry point with CLIConfig + cli_main\n└── <env_or_data>.py   # Dataset/environment definitions (if needed)\n```\n\nOr for simple recipes: `tinker_cookbook/recipes/<recipe_name>.py`\n\n### CLI pattern (use `@chz.chz` for config)\n```python\n@chz.chz\nclass CLIConfig:\n    model_name: str = \"meta-llama/Llama-3.1-8B\"\n    learning_rate: float = 1e-4\n    # ... all configurable parameters with defaults\n\nasync def cli_main(cli_config: CLIConfig):\n    # Build full config from CLI config\n    # Call training main function\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config))\n```\n\n### Required elements\n1. **`@chz.chz` config class** with sensible defaults\n2. **`model_info.get_recommended_renderer_name(model_name)`** for renderer — never hardcode\n3. **`cli_utils.check_log_dir()`** before training to avoid clobbering\n4. **`checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async()`** if loading checkpoints\n5. **Explicit typing** — no `Any` or `type: ignore`\n6. **Auto-generated log paths** with model name, hyperparams, and timestamp\n\n### Naming conventions\n- Subscript suffixes for tensors: `_P` (problems), `_G` (groups), `_T` (tokens), `_D` (datums)\n- Use `safezip`, `timed`, `scope` helpers where appropriate\n- Use `ml_log.log_metrics` for metrics, `logtree` for transcripts\n\n### Entry point\nRecipe must be runnable as:\n```bash\npython -m tinker_cookbook.recipes.<recipe_name>.train [chz overrides]\n```\n\n## Step 4: Create the recipe\n\nWrite the recipe files following the patterns above. Place them in `tinker_cookbook/recipes/$ARGUMENTS/`.\n\n## Step 5: Add tests\n\nThe repo has two layers of testing. **Both should be added for every new recipe.**\n\n### Smoke test (required)\nCreate `tests/recipes/test_recipe_<name>.py` — a minimal test that runs the recipe for 2 training steps and verifies clean exit. 
CI auto-discovers these files and runs them daily.\n\n```python\nimport pytest\nfrom tests.helpers import run_recipe\n\n@pytest.mark.integration\ndef test_<recipe_name>():\n    run_recipe(\n        \"tinker_cookbook.recipes.<recipe_name>.train\",\n        [\n            \"behavior_if_log_dir_exists=delete\",\n            # Override params to make it fast:\n            # \"groups_per_batch=4\", \"group_size=2\", \"batch_size=16\", etc.\n        ],\n    )\n```\n\nKey conventions:\n- `run_recipe()` launches the module as a subprocess and automatically passes `max_steps=2` (configurable via the `max_steps` parameter)\n- The recipe runs for 2 training steps and exits naturally — the test passes on clean exit (exit code 0)\n- Always pass `behavior_if_log_dir_exists=delete` to avoid conflicts in repeated CI runs\n- Override batch sizes / group sizes to small values so the test completes quickly\n- Mark tests with `@pytest.mark.integration` — these require `TINKER_API_KEY`\n- See `tests/helpers.py` for `run_recipe()` details and `tests/conftest.py` for fixtures\n\n### Unit tests (for testable components)\nPlace unit tests next to the code they test using the `*_test.py` naming convention:\n\n```\ntinker_cookbook/recipes/<recipe_name>/<component>_test.py\n```\n\nFor example:\n- `tinker_cookbook/recipes/math_rl/math_env_test.py` — tests environment logic\n- `tinker_cookbook/renderers/parsing_test.py` — tests parsing helpers\n\nUnit tests should:\n- Run without `TINKER_API_KEY` (no network calls)\n- Be fast (< 1s per test)\n- Use standard pytest features (fixtures, parametrize, marks)\n- Test picklability if the component needs to be serialized for distributed rollout\n\n### Running tests locally\n\n```bash\n# Unit tests only (no API key needed)\nuv run pytest tinker_cookbook/\n\n# Integration / smoke tests (requires TINKER_API_KEY)\nuv run pytest tests/recipes/test_recipe_<name>.py -v -x -s\n```\n\n### CI integration\n- **Unit tests** (`pytest tinker_cookbook/`) run on every PR via `.github/workflows/pytest.yaml`\n- **Integration tests** (`pytest tests/`) run daily and on manual trigger via `.github/workflows/smoke-test-recipes.yaml`\n- Adding `tests/recipes/test_recipe_<name>.py` is all that's needed — CI auto-discovers it\n\n## Step 6: Verify\n\n- Ensure the recipe is importable: `python -c \"from tinker_cookbook.recipes.<name> import train\"`\n- Check that CLI help works: `python -m tinker_cookbook.recipes.<name>.train --help`\n- Run the smoke test locally: `uv run pytest tests/recipes/test_recipe_<name>.py -v -x -s`\n"
  },
  {
    "path": ".claude/skills/renderers/SKILL.md",
    "content": "---\nname: renderers\ndescription: Guide for using renderers — the bridge between chat-style messages and token sequences. Covers renderer setup, TrainOnWhat, vision inputs, model family matching, and custom renderers. Use when the user asks about renderers, tokenization, message formatting, or vision inputs.\n---\n\n# Renderers\n\nRenderers convert chat-style messages into token sequences for training and generation.\n\n## Reference\n\nRead these for details:\n- `tinker_cookbook/renderers/base.py` — Renderer base class and API\n- `tinker_cookbook/renderers/__init__.py` — Registry, factory, TrainOnWhat enum\n- `docs/rendering.mdx` — Rendering guide with examples\n\n## Getting a renderer\n\nAlways use `model_info.get_recommended_renderer_name()` — never hardcode:\n\n```python\nfrom tinker_cookbook import model_info\nfrom tinker_cookbook.renderers import get_renderer\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nrenderer_name = model_info.get_recommended_renderer_name(model_name)\ntokenizer = get_tokenizer(model_name)\nrenderer = get_renderer(renderer_name, tokenizer)\n```\n\n**Available renderers:** `llama3`, `qwen3`, `deepseekv3`, `kimi_k2`, `kimi_k25`, `nemotron3`, `nemotron3_disable_thinking`, `role_colon`, and more. See `tinker_cookbook/renderers/__init__.py` for the full registry.\n\n## Key renderer methods\n\n```python\n# Build generation prompt (for sampling)\nmodel_input = renderer.build_generation_prompt(messages, role=\"assistant\")\n\n# Build supervised example (for training)\nmodel_input, weights = renderer.build_supervised_example(\n    messages, train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES\n)\n\n# Parse model output back to a message\nmessage, is_complete = renderer.parse_response(token_ids)\n\n# Get stop sequences for sampling\nstop = renderer.get_stop_sequences()\n\n# Tool calling support\nprefix_messages = renderer.create_conversation_prefix_with_tools(tool_specs)\n```\n\n## TrainOnWhat\n\nControls which tokens receive training signal:\n\n```python\nfrom tinker_cookbook.renderers import TrainOnWhat\n\n# Most common — train on all assistant responses\nTrainOnWhat.ALL_ASSISTANT_MESSAGES\n\n# Train only on the final assistant response\nTrainOnWhat.LAST_ASSISTANT_MESSAGE\n\n# Train on everything (including user messages)\nTrainOnWhat.ALL_TOKENS\n\n# Other options\nTrainOnWhat.LAST_ASSISTANT_TURN\nTrainOnWhat.ALL_MESSAGES\nTrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES\nTrainOnWhat.CUSTOMIZED  # Set trainable=True/False on individual messages\n```\n\n## Vision inputs\n\nFor VLM models, use `ImageChunk` in messages:\n\n```python\nmessage = {\n    \"role\": \"user\",\n    \"content\": [\n        {\"type\": \"image\", \"image_url\": \"https://...\"},  # or local path\n        {\"type\": \"text\", \"text\": \"What is in this image?\"},\n    ],\n}\n```\n\nSee `docs/rendering.mdx` and `tinker_cookbook/recipes/vlm_classifier/train.py` for VLM examples.\n\n## Custom renderers\n\nRegister a custom renderer:\n\n```python\nfrom tinker_cookbook.renderers import register_renderer\n\ndef my_renderer_factory(tokenizer, image_processor):\n    return MyCustomRenderer(tokenizer)\n\nregister_renderer(\"my_renderer\", my_renderer_factory)\n```\n\n## Picklability\n\nRenderers must be pickleable for distributed rollout execution. 
The codebase tests this — see `tinker_cookbook/renderers/renderer_pickle_test.py`.\n\n## Common pitfalls\n- Always use `model_info.get_recommended_renderer_name()` — renderer must match model family\n- After loading a checkpoint trained with a specific renderer, use the same renderer name\n- `build_supervised_example()` returns weights as `list[float]` — wrap with `TensorData.from_numpy()` if needed\n- For tool calling, use `create_conversation_prefix_with_tools()` to inject tool definitions\n"
  },
  {
    "path": ".claude/skills/rlhf/SKILL.md",
    "content": "---\nname: rlhf\ndescription: Set up and run the full RLHF pipeline (SFT, reward model training, RL from reward model) using the Tinker API. Use when the user wants to do RLHF, train a reward model, or run the full preference-based RL pipeline.\nargument-hint: \"[model-name]\"\n---\n\n# RL from Human Feedback (RLHF) Pipeline\n\nHelp the user set up and run the full 3-stage RLHF pipeline using the Tinker API.\n\n## Overview\n\nRLHF is a multi-stage pipeline:\n1. **SFT Stage** — Fine-tune base model on instruction data\n2. **Reward Model (RM) Stage** — Train a reward model on preference comparisons\n3. **RL Stage** — Optimize the SFT policy using the reward model\n\n## Step 1: Understand the request\n\nAsk the user (if not already specified):\n- **Base model**: Which model to start from (e.g., `meta-llama/Llama-3.2-3B`)\n- **Preference data**: Which comparison dataset (HHH, HelpSteer3, UltraFeedback, or custom)\n- **Which stages to run**: All 3, or skip SFT/RM if checkpoints exist\n- **LoRA rank**: Typically 64 for RLHF\n\n## Step 2: Reference existing recipes\n\nRead these files:\n- `tinker_cookbook/recipes/preference/rlhf/rlhf_pipeline.py` — Complete 3-stage pipeline\n- `tinker_cookbook/rl/preference_envs.py` — Preference-based RL environments\n- `tinker_cookbook/preference/types.py` — PreferenceModelBuilder\n- `tinker_cookbook/preference/comparison_policy_evaluator.py` — RM evaluation\n- `docs/preferences/rlhf-example.mdx` — RLHF guide\n\n## Step 3: Configure each stage\n\n### Stage 1: SFT\nStandard supervised fine-tuning (see `/sft` skill). Key settings:\n- Dataset: NoRobots or similar instruction data\n- `sft_learning_rate`: 2e-4 (LoRA)\n- `train_on_what`: `TrainOnWhat.ALL_ASSISTANT_MESSAGES`\n\n### Stage 2: Reward Model\nTrain on preference comparisons:\n- Uses `ChatDatasetBuilderFromComparisons` with a comparison builder (e.g., `HHHComparisonBuilder`)\n- `rm_learning_rate`: 3e-4\n- Produces a reward model checkpoint used in Stage 3\n\n### Stage 3: RL from Reward Model\nOptimize SFT policy using RM scores:\n- Load SFT checkpoint as starting policy\n- Load RM weights for scoring\n- `PreferenceModelBuilderFromChatRenderer` wraps the RM\n- `PairwisePreferenceRLDatasetBuilder` creates the RL environment\n- `rl_learning_rate`: 1e-5 (much lower than SFT)\n- `tournament_pattern`: `ALL_PAIRS_BOTH_WAYS` for pairwise comparison\n\n### Typical Hyperparameters\n```python\n@chz.chz\nclass CLIConfig:\n    base_model: str = \"meta-llama/Llama-3.2-3B\"\n    lora_rank: int = 64\n    batch_size: int = 256\n    max_length: int = 16384\n    sft_learning_rate: float = 2e-4\n    rm_learning_rate: float = 3e-4\n    rl_learning_rate: float = 1e-5\n    rl_max_tokens: int = 1024\n    rl_group_size: int = 4\n```\n\n## Step 4: Write the training script\n\nFollow the pipeline pattern from `rlhf_pipeline.py`:\n\n```python\nimport asyncio\nimport os\nimport chz\nfrom tinker_cookbook import checkpoint_utils, model_info\nfrom tinker_cookbook.preference.types import PreferenceModelBuilderFromChatRenderer\nfrom tinker_cookbook.rl import preference_envs, train\nfrom tinker_cookbook.supervised import train as supervised_train\n\n# Stage 1: SFT\ndef sft_stage(log_path, base_model, ...):\n    # Standard SFT config + supervised_train.main()\n    ...\n\n# Stage 2: Reward Model\ndef train_rm(log_path, base_model, ...):\n    # Train on preference comparisons\n    ...\n\n# Stage 3: RL\nasync def train_rl(log_path, sft_log_path, rm_log_path, base_model, ...):\n    sft_checkpoint = 
checkpoint_utils.get_last_checkpoint(sft_log_path)[\"state_path\"]\n    rm_weights = checkpoint_utils.get_last_checkpoint(rm_log_path)[\"sampler_path\"]\n\n    preference_model_builder = PreferenceModelBuilderFromChatRenderer(\n        renderer_name=renderer_name,\n        model_name=base_model,\n        rm_weights_path=rm_weights,\n    )\n    rl_dataset_builder = preference_envs.PairwisePreferenceRLDatasetBuilder(\n        comparison_builder=comparison_builder,\n        preference_model_builder=preference_model_builder,\n        batch_size=batch_size,\n        group_size=group_size,\n        tournament_pattern=preference_envs.TournamentPattern.ALL_PAIRS_BOTH_WAYS,\n        ...\n    )\n    config = train.Config(\n        load_checkpoint_path=sft_checkpoint,\n        dataset_builder=rl_dataset_builder,\n        learning_rate=1e-5,\n        ...\n    )\n    await train.main(config)\n```\n\n## Step 5: Run\n\n```bash\n# Full pipeline\npython -m tinker_cookbook.recipes.preference.rlhf.rlhf_pipeline\n\n# Skip SFT (already have checkpoint)\npython -m tinker_cookbook.recipes.preference.rlhf.rlhf_pipeline run_sft=False\n\n# Skip SFT and RM\npython -m tinker_cookbook.recipes.preference.rlhf.rlhf_pipeline run_sft=False run_rm=False\n```\n\n## Step 6: Add tests\n\nIf you created a new RLHF recipe, add a smoke test:\n\n```python\n# tests/recipes/test_recipe_<name>.py\nimport pytest\nfrom tests.helpers import run_recipe\n\n@pytest.mark.integration\ndef test_<recipe_name>():\n    run_recipe(\n        \"tinker_cookbook.recipes.<recipe_name>.train\",\n        [\"behavior_if_log_dir_exists=delete\"],\n    )\n```\n\n`run_recipe()` automatically passes `max_steps=2` so the recipe runs 2 training steps and exits. See `tests/recipes/test_recipe_rlhf_pipeline.py` for the existing example.\n\n## Common pitfalls\n- RL learning rate must be **much lower** than SFT (1e-5 vs 2e-4)\n- Checkpoints flow between stages: SFT → RL policy init, RM → RL reward scoring\n- Use `checkpoint_utils.get_last_checkpoint()` to find checkpoints from previous stages\n- RM quality directly impacts RL — validate RM before running Stage 3\n- `group_size` in RL stage affects variance of reward estimates\n"
  },
  {
    "path": ".claude/skills/setup/SKILL.md",
    "content": "---\nname: setup\ndescription: Guide for installing Tinker, setting up the environment, getting an API key, and verifying everything works. Use when the user is getting started, setting up their environment, or troubleshooting installation issues.\n---\n\n# Setup & Installation\n\nGet Tinker and tinker-cookbook running from scratch.\n\n## Reference\n\n- `docs/install.mdx` — Official installation guide\n- `CONTRIBUTING.md` — Development setup\n- `README.md` — Project overview\n\n## Step 1: Sign up and get an API key\n\n1. Sign up at [https://auth.thinkingmachines.ai/sign-up](https://auth.thinkingmachines.ai/sign-up)\n2. Create an API key from the [console](https://tinker-console.thinkingmachines.ai)\n3. Export it:\n```bash\nexport TINKER_API_KEY=<your-key>\n```\n\nAdd to your shell profile (`.bashrc`, `.zshrc`) for persistence.\n\n## Step 2: Install Tinker SDK\n\n```bash\npip install tinker\n```\n\nThis gives you:\n- **Python SDK** — `TrainingClient`, `SamplingClient`, low-level training/sampling APIs\n- **Tinker CLI** — `tinker` or `python -m tinker` for management tasks\n\n## Step 3: Install tinker-cookbook\n\n```bash\ngit clone https://github.com/thinking-machines-lab/tinker-cookbook.git\ncd tinker-cookbook\npip install -e .\n```\n\nOr with dev dependencies (for contributing):\n```bash\nuv sync --extra dev\npre-commit install\n```\n\n## Step 4: Verify installation\n\n```python\nimport tinker\nservice_client = tinker.ServiceClient()\ntraining_client = service_client.create_lora_training_client(\n    base_model=\"meta-llama/Llama-3.2-1B\", rank=32,\n)\ninfo = training_client.get_info()\nprint(info)  # Should print model info\n```\n\n## Step 5: Run a minimal example\n\n```bash\n# Supervised learning\npython -m tinker_cookbook.recipes.sl_basic\n\n# Reinforcement learning\npython -m tinker_cookbook.recipes.rl_basic\n```\n\n## Environment variables\n\n| Variable | Purpose |\n|----------|---------|\n| `TINKER_API_KEY` | Required — authenticates with Tinker service |\n| `HF_TOKEN` | Optional — access gated HuggingFace models (Llama, etc.) |\n| `HF_TRUST_REMOTE_CODE` | Optional — allow custom tokenizer code |\n| `WANDB_API_KEY` | Optional — log to Weights & Biases |\n\n## Common issues\n\n- **`TINKER_API_KEY not set`**: Export the key in your shell or `.env` file\n- **Tokenizer download fails**: Set `HF_TOKEN` for gated models (e.g., Llama)\n- **Import errors**: Ensure `pip install -e .` was run from the repo root\n- **`uv` not found**: Install with `curl -LsSf https://astral.sh/uv/install.sh | sh`\n"
  },
  {
    "path": ".claude/skills/sft/SKILL.md",
    "content": "---\nname: sft\ndescription: Set up and run supervised fine-tuning (SFT) on instruction or chat datasets using the Tinker API. Use when the user wants to do instruction tuning, chat fine-tuning, or supervised learning.\nargument-hint: \"[model-name] [dataset]\"\n---\n\n# Supervised Fine-Tuning (SFT)\n\nHelp the user set up and run supervised fine-tuning using the Tinker API.\n\n## Step 1: Understand the request\n\nAsk the user (if not already specified):\n- **Model**: Which base model to fine-tune (e.g., `meta-llama/Llama-3.1-8B`, `Qwen/Qwen3-8B`). See `docs/model-lineup.mdx` for available models.\n- **Dataset**: What data to train on — built-in datasets (NoRobots, Tulu3) or custom JSONL file.\n- **Goal**: General instruction tuning, domain-specific fine-tuning, or chat quality improvement.\n\n## Step 2: Reference existing recipes\n\nRead these files for patterns and conventions:\n- `tinker_cookbook/recipes/sl_basic.py` — Minimal SFT example\n- `tinker_cookbook/recipes/chat_sl/train.py` — Full-featured chat SFT with eval\n- `tinker_cookbook/supervised/train.py` — Core training loop\n- `tinker_cookbook/supervised/data.py` — Dataset construction helpers\n- `docs/supervised-learning/sl-basic.mdx` — Getting started guide\n- `docs/supervised-learning/sl-hyperparams.mdx` — Learning rate and batch size guidance\n\n## Step 3: Configure the training run\n\nKey configuration decisions:\n\n### Renderer\nMatch renderer to model family using `model_info.get_recommended_renderer_name(model_name)`. Never hardcode renderer names.\n\n### Learning Rate\n- Use `hyperparam_utils.get_lr(model_name)` for recommended LR\n- LoRA fine-tuning typically needs ~10x higher LR than full fine-tuning (e.g., 2e-4 for LoRA vs 2e-5 for full)\n\n### TrainOnWhat\n- `TrainOnWhat.ALL_ASSISTANT_MESSAGES` — Train on all assistant turns (most common)\n- `TrainOnWhat.LAST_ASSISTANT_MESSAGE` — Train only on final assistant response\n- `TrainOnWhat.EVERYTHING` — Train on entire conversation including user messages\n\n### Dataset\n- Built-in: `NoRobotsBuilder`, `Tulu3Builder`\n- Custom JSONL: Use `FromConversationFileBuilder(common_config=..., file_path=\"path/to/data.jsonl\")`\n- Format: Same as `tinker_cookbook/example_data/conversations.jsonl`\n\n### Batch Size & Epochs\n- `batch_size`: Number of tokens per training batch (default: 128 for basic, scale up as needed)\n- `num_epochs`: Number of passes through the dataset\n- `eval_every`: Evaluate every N batches\n\n## Step 4: Write the training script\n\nFollow the pattern from `sl_basic.py`:\n\n```python\nimport asyncio\nimport chz\nimport sys\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.supervised import train\nfrom tinker_cookbook.renderers import TrainOnWhat\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilderCommonConfig\n\ndef build_config_blueprint() -> chz.Blueprint[train.Config]:\n    model_name = \"meta-llama/Llama-3.1-8B\"\n    renderer_name = model_info.get_recommended_renderer_name(model_name)\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=model_name,\n        renderer_name=renderer_name,\n        max_length=32768,\n        batch_size=128,\n        train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES,\n    )\n    # Configure dataset builder here\n    dataset = ...\n\n    return chz.Blueprint(train.Config).apply({\n        \"log_path\": \"/tmp/tinker-examples/my_sft_run\",\n        \"model_name\": model_name,\n        \"renderer_name\": renderer_name,\n        \"dataset_builder\": 
dataset,\n        \"learning_rate\": 2e-4,\n        \"lr_schedule\": \"linear\",\n        \"num_epochs\": 1,\n        \"eval_every\": 8,\n    })\n\ndef main(config: train.Config):\n    cli_utils.check_log_dir(config.log_path, behavior_if_exists=\"ask\")\n    asyncio.run(train.main(config))\n\nif __name__ == \"__main__\":\n    blueprint = build_config_blueprint()\n    blueprint.make_from_argv(sys.argv[1:])\n    main(blueprint.make())\n```\n\n## Step 5: Run and iterate\n\n```bash\npython -m tinker_cookbook.recipes.<recipe_name>\n```\n\nOverride parameters from CLI: `python -m tinker_cookbook.recipes.<recipe_name> learning_rate=1e-4 batch_size=256`\n\n## Step 6: Add tests\n\nIf you created a new recipe, add a smoke test so CI catches regressions:\n\n```python\n# tests/recipes/test_recipe_<name>.py\nimport pytest\nfrom tests.helpers import run_recipe\n\n@pytest.mark.integration\ndef test_<recipe_name>():\n    run_recipe(\n        \"tinker_cookbook.recipes.<recipe_name>.train\",\n        [\"behavior_if_log_dir_exists=delete\", \"batch_size=16\"],\n    )\n```\n\n`run_recipe()` automatically passes `max_steps=2` to the recipe, so it runs 2 training steps and exits cleanly. Tests are auto-discovered by CI and run daily. For unit-testable components (dataset processing, custom logic), add `*_test.py` files next to the source code.\n\n## Step 7: Export weights (optional)\n\nAfter training, export weights using the `tinker_cookbook.weights` API:\n\n```python\nfrom tinker_cookbook import weights\n\nadapter_dir = weights.download(tinker_path=\"tinker://run-id/sampler_weights/final\", output_dir=\"./adapter\")\nweights.build_hf_model(base_model=\"meta-llama/Llama-3.1-8B\", adapter_path=adapter_dir, output_path=\"./model\")\nweights.publish_to_hf_hub(model_path=\"./model\", repo_id=\"user/my-finetuned-model\")\n```\n\n## Common pitfalls\n- Always use `model_info.get_recommended_renderer_name()` — never hardcode renderer\n- Use `cli_utils.check_log_dir()` to avoid clobbering previous runs\n- For custom datasets, ensure JSONL matches the conversation format in `example_data/conversations.jsonl`\n- LR too high causes instability; LR too low wastes compute\n"
  },
  {
    "path": ".claude/skills/tinker-cli/SKILL.md",
    "content": "---\nname: tinker-cli\ndescription: Guide for the Tinker CLI — managing training runs, checkpoints, downloading weights, and publishing to HuggingFace. Use when the user asks about CLI commands, listing runs, managing checkpoints from the terminal, or uploading to HF Hub.\n---\n\n# Tinker CLI\n\nThe `tinker` CLI is installed with the Tinker Python SDK. It provides commands for managing training runs and checkpoints from the terminal.\n\nRequires `TINKER_API_KEY` environment variable to be set.\n\n## Global options\n\n```bash\ntinker --format table   # Rich table output (default)\ntinker --format json    # JSON output (for scripting)\n```\n\n## Training runs\n\n```bash\n# List recent training runs\ntinker run list\ntinker run list --limit 50\n\n# Show details for a specific run\ntinker run info <RUN_ID>\n\n# Custom columns\ntinker run list --columns id,model,lora,updated,status,checkpoint\n```\n\nAvailable columns: `id`, `model`, `owner`, `lora`, `updated`, `status`, `checkpoint`, `checkpoint_time`.\n\n## Checkpoints\n\n### List and inspect\n\n```bash\n# List checkpoints for a specific run\ntinker checkpoint list --run-id <RUN_ID>\n\n# List all your checkpoints across runs\ntinker checkpoint list\ntinker checkpoint list --limit 50\n\n# Show checkpoint details\ntinker checkpoint info <TINKER_PATH>\n```\n\n### Download\n\n```bash\n# Download and extract a checkpoint\ntinker checkpoint download <TINKER_PATH>\ntinker checkpoint download <TINKER_PATH> --output ./my-adapter\ntinker checkpoint download <TINKER_PATH> --force  # Overwrite existing\n```\n\n### Visibility\n\n```bash\n# Make a checkpoint publicly accessible\ntinker checkpoint publish <TINKER_PATH>\n\n# Make a checkpoint private\ntinker checkpoint unpublish <TINKER_PATH>\n```\n\n### TTL (expiration)\n\n```bash\n# Set checkpoint to expire in 24 hours\ntinker checkpoint set-ttl <TINKER_PATH> --ttl 86400\n\n# Remove expiration (keep indefinitely)\ntinker checkpoint set-ttl <TINKER_PATH> --remove\n```\n\n### Delete\n\n```bash\n# Delete checkpoints (with confirmation prompt)\ntinker checkpoint delete <TINKER_PATH>\n\n# Delete without confirmation\ntinker checkpoint delete <TINKER_PATH> -y\n\n# Delete multiple\ntinker checkpoint delete <PATH1> <PATH2> <PATH3>\n```\n\n### Upload to HuggingFace Hub\n\n```bash\n# Push checkpoint to HuggingFace\ntinker checkpoint push-hf <TINKER_PATH> --repo user/my-model\n\n# Push as public repo\ntinker checkpoint push-hf <TINKER_PATH> --repo user/my-model --public\n\n# Advanced options\ntinker checkpoint push-hf <TINKER_PATH> \\\n    --repo user/my-model \\\n    --revision main \\\n    --commit-message \"Upload fine-tuned model\" \\\n    --create-pr \\\n    --no-model-card\n```\n\nOptions: `--repo`, `--public`, `--revision`, `--commit-message`, `--create-pr`, `--allow-pattern`, `--ignore-pattern`, `--no-model-card`.\n\n## Version\n\n```bash\ntinker version   # e.g. \"tinker 0.15.0\"\n```\n\n## Common patterns\n\n### Script-friendly output\n```bash\n# Get checkpoint paths as JSON for scripting\ntinker checkpoint list --format json | jq '.[].tinker_path'\n\n# Get run IDs\ntinker run list --format json | jq '.[].id'\n```\n\n### Typical workflow\n```bash\n# 1. Find your training run\ntinker run list\n\n# 2. List checkpoints for that run\ntinker checkpoint list --run-id <RUN_ID>\n\n# 3. Download the final checkpoint\ntinker checkpoint download tinker://<RUN_ID>/sampler_weights/final -o ./adapter\n\n# 4. 
Or push directly to HuggingFace\ntinker checkpoint push-hf tinker://<RUN_ID>/sampler_weights/final --repo user/my-model\n```\n\n## Common pitfalls\n- `TINKER_API_KEY` must be set — the CLI reads it from the environment\n- Checkpoint paths use the format `tinker://<run-id>/<type>/<checkpoint-id>`\n- `push-hf` uploads the raw checkpoint — for merged HF models, use `weights.build_hf_model()` in Python first (see `/weights` skill)\n- `delete` is permanent and irreversible — use `-y` flag carefully\n"
  },
  {
    "path": ".claude/skills/tinker-sdk/SKILL.md",
    "content": "---\nname: tinker-sdk\ndescription: Guide for using the Tinker Python SDK APIs — ServiceClient, TrainingClient, SamplingClient, RestClient, forward_backward, optim_step, sampling, and async patterns. Use when the user asks about Tinker API basics, how to call training/sampling, or how the SDK works.\n---\n\n# Tinker Python SDK\n\nHelp the user understand and use the core Tinker SDK APIs.\n\n## Reference docs\n\nRead these for authoritative API documentation:\n- `docs/api-reference/serviceclient.md` — ServiceClient API\n- `docs/api-reference/trainingclient.md` — TrainingClient API\n- `docs/api-reference/samplingclient.md` — SamplingClient API\n- `docs/api-reference/restclient.md` — RestClient API\n- `docs/api-reference/types.md` — All SDK types\n- `docs/training-sampling.mdx` — Starter walkthrough\n- `docs/async.mdx` — Sync/async patterns, futures\n- `docs/losses.mdx` — Loss functions\n- `docs/under-the-hood.mdx` — Clock cycles, worker pools\n\n## ServiceClient (entry point)\n\n`ServiceClient` is the main entry point. All other clients are created from it.\n\n```python\nfrom tinker import ServiceClient\n\nsvc = ServiceClient(user_metadata={\"experiment\": \"v1\"}, project_id=\"my-project\")\n\n# Create a new LoRA training client\ntc = svc.create_lora_training_client(\n    base_model=\"Qwen/Qwen3-8B\",\n    rank=32,\n    seed=None,\n    train_mlp=True,\n    train_attn=True,\n    train_unembed=True,\n)\n\n# Resume from a training checkpoint\ntc = svc.create_training_client_from_state(path=\"tinker://...\")              # weights only\ntc = svc.create_training_client_from_state_with_optimizer(path=\"tinker://...\") # weights + optimizer\n\n# Create a sampling client\nsc = svc.create_sampling_client(model_path=\"tinker://...\", base_model=None, retry_config=None)\n\n# Create a REST client for checkpoint/run management\nrest = svc.create_rest_client()\n\n# Query available models\ncaps = svc.get_server_capabilities()  # returns GetServerCapabilitiesResponse\n```\n\nAll creation methods have `_async` variants.\n\n## TrainingClient\n\n```python\n# Forward/backward pass (compute loss + gradients)\nresult = tc.forward_backward(data=[datum1, datum2], loss_fn=\"cross_entropy\")\n\n# Forward-only pass (compute loss, no gradients — useful for eval)\nresult = tc.forward(data=[datum1, datum2], loss_fn=\"cross_entropy\")\n\n# Custom loss function\nresult = tc.forward_backward_custom(data=[datum1, datum2], loss_fn=my_custom_loss_fn)\n\n# Optimizer step\ntc.optim_step(adam_params=AdamParams(learning_rate=2e-4))\n\n# Checkpointing\ntc.save_state(name=\"step_100\", ttl_seconds=None)                # Full state (resumable)\ntc.save_weights_for_sampler(name=\"step_100_sampler\", ttl_seconds=None)  # Sampler-only\n\n# Save + get SamplingClient in one call\nsc = tc.save_weights_and_get_sampling_client(name=\"step_100\")\n\n# Load checkpoint\ntc.load_state(path=\"tinker://...\")\ntc.load_state_with_optimizer(path=\"tinker://...\")\n\n# Metadata\ninfo = tc.get_info()          # GetInfoResponse (model name, LoRA rank, tokenizer)\ntokenizer = tc.get_tokenizer()  # HuggingFace tokenizer\n```\n\n### Loss functions\n- `\"cross_entropy\"` — Standard SL loss\n- `\"importance_sampling\"` — On-policy RL (default for GRPO)\n- `\"ppo\"` — Proximal Policy Optimization\n- `\"cispo\"` — Conservative Importance Sampling PPO\n- `\"dro\"` — Distributionally Robust Optimization\n\nSee `docs/losses.mdx` for details and `loss_fn_config` parameters.\n\n### Async variants\n\nAll methods have `_async` variants that return 
`APIFuture`:\n```python\nfb_future = tc.forward_backward_async(data=data, loss_fn=\"cross_entropy\")\noptim_future = tc.optim_step_async(adam_params=adam_params)\n# Do other work...\nfb_result = fb_future.result()\noptim_result = optim_future.result()\n```\n\n**Key pattern:** Submit `forward_backward_async` and `optim_step_async` back-to-back before awaiting — this overlaps GPU computation with data preparation.\n\n## SamplingClient\n\n```python\nfrom tinker import SamplingParams\n\nsc = tc.save_weights_and_get_sampling_client(name=\"step_100\")\n\nresponse = sc.sample(\n    prompt=model_input,\n    num_samples=4,\n    sampling_params=SamplingParams(max_tokens=256, temperature=1.0),\n    include_prompt_logprobs=False,   # Set True to get per-token prompt logprobs\n    topk_prompt_logprobs=0,          # Top-K logprobs per prompt token (0 = disabled)\n)\n\nfor seq in response.sequences:\n    print(seq.tokens, seq.logprobs, seq.stop_reason)\n\n# Get logprobs for existing tokens (no generation)\nlogprobs_response = sc.compute_logprobs(prompt=model_input)\n\n# Metadata\nbase_model = sc.get_base_model()    # Base model name string\ntokenizer = sc.get_tokenizer()      # HuggingFace tokenizer\n```\n\nSamplingClient is picklable for multiprocessing use.\n\n**Important:** Always create a **new** SamplingClient after saving weights. A stale client points at old weights.\n\n## RestClient\n\nFor managing training runs and checkpoints. See also the `/tinker-cli` skill for CLI equivalents.\n\n```python\nrest = svc.create_rest_client()\n\n# Training runs\nruns = rest.list_training_runs(limit=20, offset=0, access_scope=\"owned\")\nrun = rest.get_training_run(training_run_id=\"...\")\nrun = rest.get_training_run_by_tinker_path(tinker_path=\"tinker://...\")\n\n# Checkpoints\ncheckpoints = rest.list_checkpoints(training_run_id=\"...\")\nall_checkpoints = rest.list_user_checkpoints(limit=100, offset=0)\nrest.delete_checkpoint(training_run_id=\"...\", checkpoint_id=\"...\")\nrest.delete_checkpoint_from_tinker_path(tinker_path=\"tinker://...\")\n\n# Checkpoint visibility\nrest.publish_checkpoint_from_tinker_path(tinker_path=\"tinker://...\")    # Make public\nrest.unpublish_checkpoint_from_tinker_path(tinker_path=\"tinker://...\")  # Make private\n\n# Checkpoint TTL\nrest.set_checkpoint_ttl_from_tinker_path(tinker_path=\"tinker://...\", ttl_seconds=86400)\n\n# Download URL\nurl_resp = rest.get_checkpoint_archive_url_from_tinker_path(tinker_path=\"tinker://...\")\n\n# Checkpoint metadata\ninfo = rest.get_weights_info_by_tinker_path(tinker_path=\"tinker://...\")\n```\n\nAll RestClient methods have `_async` variants.\n\n## Retry behavior\n\nThe Tinker SDK retries **all** HTTP API calls automatically (10 attempts, exponential backoff with jitter). Retried request types: timeouts (408), lock conflicts (409), rate limits (429), server errors (500+), and connection failures. The SDK respects `Retry-After` headers and attaches idempotency keys to non-GET requests.\n\nClient errors (400, 401, 403, 404, 422) are **not** retried — these raise immediately (e.g., `tinker.BadRequestError`, `tinker.AuthenticationError`).\n\nOverride via `max_retries` on client creation:\n```python\nsvc = tinker.ServiceClient(max_retries=3)   # reduce retries\nsvc = tinker.ServiceClient(max_retries=0)   # disable retries\n```\n\n**Do not** add retry wrappers around Tinker API calls in training loops — the SDK handles this. 
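For the non-retried client errors, handle them explicitly where it matters. A minimal sketch (the path is a placeholder, and which exception a given failure raises depends on the call):\n\n```python\nimport tinker\n\nsvc = tinker.ServiceClient()\ntry:\n    sc = svc.create_sampling_client(model_path=\"tinker://run-id/sampler_weights/final\")\nexcept (tinker.NotFoundError, tinker.BadRequestError) as e:\n    # Client errors (4xx) are raised immediately rather than retried\n    print(f\"request rejected: {e}\")\n```\n\n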
Enable retry logging with `logging.getLogger(\"tinker\").setLevel(logging.DEBUG)`.\n\n## Common pitfalls\n- **Use ServiceClient** to create clients — `TrainingClient` and `SamplingClient` cannot be constructed directly\n- Await outstanding futures before submitting the next `forward_backward` call\n- Within an iteration, submit `forward_backward_async` + `optim_step_async` back-to-back before awaiting\n- Create a **new** SamplingClient after saving weights (sampler desync)\n- Use `save_state` for resumable checkpoints, `save_weights_for_sampler` for sampling-only\n- `forward()` computes loss without gradients — use for eval, not training\n"
  },
  {
    "path": ".claude/skills/tinker-types/SKILL.md",
    "content": "---\nname: tinker-types\ndescription: Reference for Tinker SDK types — Datum, ModelInput, TensorData, SamplingParams, response types, error types, and helper functions. Use when the user needs to build training data, construct model inputs, understand response objects, or handle errors.\n---\n\n# Tinker SDK Types\n\nQuick reference for the core types used throughout the Tinker SDK and cookbook.\n\n## Reference\n\nRead `docs/api-reference/types.md` for the complete type reference.\n\n## Core data types\n\n### Type hierarchy\n```\nDatum\n├── model_input: ModelInput (list of chunks)\n│   ├── EncodedTextChunk (token IDs)\n│   └── ImageChunk (vision inputs)\n└── loss_fn_inputs: dict[str, TensorData]\n    └── TensorData (numpy/torch wrapper)\n```\n\n### ModelInput\n```python\nfrom tinker import ModelInput\n\nmi = ModelInput.from_ints([1, 2, 3, 4, 5])  # From token list\ntokens = mi.to_ints()                        # Back to list\nlength = mi.length                           # Token count (property)\nmi2 = mi.append(chunk)                       # Append a chunk\nmi3 = mi.append_int(42)                      # Append a single token\nmi_empty = ModelInput.empty()                # Empty input\n```\n\n### TensorData\n```python\nfrom tinker import TensorData\n\ntd = TensorData.from_numpy(np.array([1.0, 0.0, 1.0]))  # From numpy\ntd = TensorData.from_torch(torch.tensor([1.0, 0.0]))    # From torch\narr = td.to_numpy()                                       # Back to numpy\ntensor = td.to_torch()                                    # Back to torch\nlst = td.tolist()                                         # Back to list\n# Fields: data (flat list), dtype (\"int64\"|\"float32\"), shape (optional)\n```\n\n### Datum\n```python\nfrom tinker import Datum, ModelInput, TensorData\n\ndatum = Datum(\n    model_input=ModelInput.from_ints(tokens),\n    loss_fn_inputs={\"weights\": TensorData.from_numpy(weights_array)},\n)\n```\n\n## Configuration types\n\n### SamplingParams\n```python\nfrom tinker import SamplingParams\n\nparams = SamplingParams(\n    max_tokens=256,        # Max generation length\n    temperature=1.0,       # Sampling temperature\n    top_k=50,              # Top-K sampling (-1 = no limit)\n    top_p=0.95,            # Nucleus sampling\n    stop=[\"<|eot_id|>\"],   # Stop sequences (strings or token IDs)\n    seed=42,               # Reproducible seed\n)\n```\n\n### AdamParams\n```python\nfrom tinker import AdamParams\n\nadam = AdamParams(\n    learning_rate=2e-4,\n    beta1=0.9,             # Gradient moving average\n    beta2=0.95,            # Gradient squared moving average\n    eps=1e-12,             # Numerical stability\n    weight_decay=0.0,      # Decoupled weight decay\n    grad_clip_norm=1.0,    # Global gradient norm clipping (0.0 = disabled)\n)\n```\n\n### LoraConfig\n```python\nfrom tinker import LoraConfig\n\nconfig = LoraConfig(\n    rank=32,               # LoRA rank\n    seed=None,             # Initialization seed\n    train_mlp=True,        # Train MLP layers\n    train_attn=True,       # Train attention layers\n    train_unembed=True,    # Train unembedding layer\n)\n```\n\n## Response types\n\n### ForwardBackwardOutput\nReturned by `forward_backward()` and `forward()`:\n```python\nresult = tc.forward_backward(data=batch, loss_fn=\"cross_entropy\")\nresult.metrics              # dict[str, float] — training metrics (includes loss)\nresult.loss_fn_outputs      # list[LossFnOutput] — per-sample outputs\nresult.loss_fn_output_type  # str — loss output class 
name\n```\n\n### SampleResponse / SampledSequence\nReturned by `sample()`:\n```python\nresponse = sc.sample(prompt=mi, num_samples=4, sampling_params=params)\nresponse.sequences                # list[SampledSequence]\nresponse.prompt_logprobs          # Optional[list[Optional[float]]] — per-prompt-token logprobs\nresponse.topk_prompt_logprobs     # Optional[list[Optional[list[tuple[int, float]]]]] — top-K\n\nfor seq in response.sequences:\n    seq.tokens       # list[int] — generated token IDs\n    seq.logprobs     # Optional[list[float]] — per-token logprobs\n    seq.stop_reason  # StopReason: \"length\" | \"stop\"\n```\n\n### Other response types\n- `OptimStepResponse` — confirms parameter update\n- `SaveWeightsResponse` — `path: str` (tinker:// path to saved weights)\n- `LoadWeightsResponse` — confirms loaded weights\n- `GetInfoResponse` — `model_data: ModelData` (model_name, lora_rank, tokenizer_id)\n- `GetServerCapabilitiesResponse` — `supported_models: list[SupportedModel]`\n- `WeightsInfoResponse` — `base_model`, `lora_rank`, `is_lora`, `train_mlp`, `train_attn`, `train_unembed`\n\n## Checkpoint and run types\n\n```python\nfrom tinker import TrainingRun, Checkpoint, CheckpointType, ParsedCheckpointTinkerPath\n\n# TrainingRun — metadata about a training run\nrun.training_run_id    # str\nrun.base_model         # str\nrun.is_lora            # bool\nrun.lora_rank          # Optional[int]\nrun.last_checkpoint    # Optional[Checkpoint]\nrun.user_metadata      # Optional[dict[str, str]]\n\n# Checkpoint — metadata about a saved checkpoint\nckpt.checkpoint_id     # str\nckpt.checkpoint_type   # CheckpointType: \"training\" | \"sampler\"\nckpt.tinker_path       # str (tinker:// path)\nckpt.size_bytes        # Optional[int]\nckpt.public            # bool\nckpt.expires_at        # Optional[datetime]\n\n# Parse a tinker:// path\nparsed = ParsedCheckpointTinkerPath.from_tinker_path(\"tinker://run-id/weights/ckpt-id\")\nparsed.training_run_id  # str\nparsed.checkpoint_type  # CheckpointType\nparsed.checkpoint_id    # str\n```\n\n## Error types\n\nAll exceptions inherit from `tinker.TinkerError`:\n- **`APIError`** → **`APIStatusError`**: `BadRequestError` (400), `AuthenticationError` (401), `PermissionDeniedError` (403), `NotFoundError` (404), `ConflictError` (409), `UnprocessableEntityError` (422), `RateLimitError` (429), `InternalServerError` (500+)\n- **`APIConnectionError`**, **`APITimeoutError`**, **`APIResponseValidationError`**\n- **`RequestFailedError`** — async request failure with error category\n\n## Cookbook helper functions\n\nUse these instead of manual Datum construction:\n- `tinker_cookbook.supervised.data.conversation_to_datum(messages, renderer, max_length, train_on_what)` — full SL pipeline\n- `tinker_cookbook.supervised.common.datum_from_model_input_weights(model_input, weights, max_length)` — from ModelInput + weights\n- `renderer.build_supervised_example(messages)` — returns `(ModelInput, weights)`\n\n## Common pitfalls\n- Use helper functions instead of manual dict construction for Datum\n- `TensorData` wraps arrays — don't pass raw numpy/torch directly to `loss_fn_inputs`\n- `ModelInput.from_ints()` expects a flat list of integers, not nested lists\n- `ModelInput.length` is a property, not a method\n- Handle `tinker.RateLimitError` in production code with exponential backoff\n"
  },
  {
    "path": ".claude/skills/weights/SKILL.md",
    "content": "---\nname: weights\ndescription: Guide for the weight lifecycle — downloading trained weights from Tinker, merging LoRA adapters into HuggingFace models, and publishing to HuggingFace Hub. Use when the user asks about exporting, downloading, merging, or publishing trained model weights.\n---\n\n# Weight Lifecycle\n\nThe `tinker_cookbook.weights` subpackage provides a standard pipeline for trained weight management: **download → build → publish**.\n\n## Reference\n\nRead these for details:\n- `tinker_cookbook/weights/__init__.py` — API overview and workflow example\n- `tinker_cookbook/weights/_download.py` — Download implementation\n- `tinker_cookbook/weights/_export.py` — LoRA merge implementation\n- `tinker_cookbook/weights/_publish.py` — HuggingFace Hub publish\n- `docs/download-weights.mdx` — Download guide\n- `docs/publish-weights.mdx` — Publishing guide\n- `docs/save-load.mdx` — Checkpointing (save_weights_for_sampler vs save_state)\n\n## Full workflow\n\n```python\nfrom tinker_cookbook import weights\n\n# Step 1: Download adapter from Tinker\nadapter_dir = weights.download(\n    tinker_path=\"tinker://run-id/sampler_weights/final\",\n    output_dir=\"./adapter\",\n)\n\n# Step 2: Merge LoRA adapter into base model\nweights.build_hf_model(\n    base_model=\"Qwen/Qwen3.5-35B-A3B\",\n    adapter_path=adapter_dir,\n    output_path=\"./model\",\n    dtype=\"bfloat16\",  # or \"float16\", \"float32\"\n)\n\n# Step 3: Publish to HuggingFace Hub\nurl = weights.publish_to_hf_hub(\n    model_path=\"./model\",\n    repo_id=\"user/my-finetuned-model\",\n    private=True,\n)\n```\n\n## API reference\n\n### `weights.download()`\nDownloads and extracts a checkpoint archive from Tinker.\n\n```python\nadapter_dir = weights.download(\n    tinker_path=\"tinker://run-id/sampler_weights/final\",  # Tinker checkpoint path\n    output_dir=\"./adapter\",      # Local directory to extract to\n    base_url=None,               # Optional custom Tinker API URL\n)\n# Returns: path to extracted directory\n```\n\n### `weights.build_hf_model()`\nMerges a LoRA adapter into a base model, producing a full HuggingFace model.\n\n```python\nweights.build_hf_model(\n    base_model=\"Qwen/Qwen3-8B\",     # HF model name or local path\n    adapter_path=\"./adapter\",        # Directory with adapter_model.safetensors\n    output_path=\"./model\",           # Where to save merged model\n    dtype=\"bfloat16\",                # Weight dtype\n    trust_remote_code=None,          # Override HF_TRUST_REMOTE_CODE\n)\n```\n\n### `weights.publish_to_hf_hub()`\nPushes a local model directory to HuggingFace Hub.\n\n```python\nurl = weights.publish_to_hf_hub(\n    model_path=\"./model\",                    # Local model directory\n    repo_id=\"user/my-finetuned-model\",       # HF repo ID\n    private=True,                            # Private repo\n    token=None,                              # HF token (uses HF_TOKEN env var if None)\n)\n# Returns: URL to published repo\n```\n\n### `weights.build_lora_adapter()` (not yet implemented)\nConvert Tinker LoRA adapter to standard format for vLLM/SGLang. Currently raises `NotImplementedError` — use `build_hf_model()` instead.\n\n## Checkpoint types (during training)\n\nDuring training, there are two types of checkpoints:\n\n- **`save_state()`** — Full state (weights + optimizer). Used for **resuming** training.\n- **`save_weights_for_sampler()`** — Weights only. 
Used for **sampling** and **export**.\n\nThe `weights.download()` function works with sampler weights (`save_weights_for_sampler` checkpoints).\n\n## Common pitfalls\n- `download()` expects a `tinker://` path from `save_weights_for_sampler`, not `save_state`\n- `build_hf_model()` requires the base model to be downloadable from HuggingFace\n- Set `HF_TOKEN` environment variable for private models and publishing\n- `dtype=\"bfloat16\"` is recommended for most models\n"
  },
  {
    "path": ".github/workflows/claude-review.yml",
    "content": "name: Claude Code\n\npermissions:\n  contents: write        # allow Claude to edit files & push commits\n  pull-requests: write   # allow PR comments/reviews & PR creation\n  issues: write          # allow issue comments & labels\n  actions: read\n\non:\n  # Respond to @claude mentions in PRs & issues (trusted users only)\n  issue_comment:\n    types: [created]\n  pull_request_review_comment:\n    types: [created]\n  issues:\n    types: [opened]\n\nenv:\n  CLAUDE_ARGS: >\n    --model claude-opus-4-5-20251101\n    --max-turns 50\n    --allowedTools \"Read\" \"Write\" \"Edit\" \"MultiEdit\"\n    \"Glob\" \"Grep\" \"LS\"\n    \"Bash(git:*)\" \"Bash(gh:*)\"\n    \"mcp__github_inline_comment__create_inline_comment\"\n\njobs:\n  claude_mention:\n    name: Respond to @claude in issues & PRs\n    runs-on: ubuntu-latest\n    # Only trusted users (OWNER, MEMBER, COLLABORATOR) can trigger Claude\n    if: >\n      (github.event_name == 'issue_comment' &&\n       contains(github.event.comment.body, '@claude') &&\n       contains(fromJSON('[\"OWNER\", \"MEMBER\", \"COLLABORATOR\"]'), github.event.comment.author_association)) ||\n      (github.event_name == 'pull_request_review_comment' &&\n       contains(github.event.comment.body, '@claude') &&\n       contains(fromJSON('[\"OWNER\", \"MEMBER\", \"COLLABORATOR\"]'), github.event.comment.author_association)) ||\n      (github.event_name == 'issues' &&\n       contains(github.event.issue.body, '@claude') &&\n       contains(fromJSON('[\"OWNER\", \"MEMBER\", \"COLLABORATOR\"]'), github.event.issue.author_association))\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v5\n        with:\n          fetch-depth: 1\n\n      - name: Claude on @mention\n        uses: anthropics/claude-code-action@v1\n        with:\n          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}\n          github_token: ${{ secrets.GITHUB_TOKEN }}\n          claude_args: ${{ env.CLAUDE_ARGS }}\n"
  },
  {
    "path": ".github/workflows/downstream-compat.yaml",
    "content": "name: downstream-compat\n\non:\n  workflow_dispatch:\n  push:\n    branches: [main]\n  pull_request:\n\njobs:\n  downstream-compat:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: checkout\n        uses: actions/checkout@v4\n\n      - name: install-uv\n        uses: astral-sh/setup-uv@v6\n        with:\n          enable-cache: true\n\n      - name: venv\n        run: uv venv && uv sync --all-extras\n\n      - name: downstream compat tests\n        run: uv run pytest tests/downstream_compat/ -v\n        env:\n          HF_TOKEN: ${{ secrets.HF_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/nightly.yaml",
    "content": "name: nightly\n\non:\n  workflow_run:\n    workflows: [\"smoke-test-recipes\"]\n    types: [completed]\n  workflow_dispatch:\n\npermissions:\n  contents: write  # needed to create/delete releases\n\njobs:\n  build-and-release:\n    runs-on: ubuntu-latest\n    # Only run on the upstream repo (forks lack secrets and shouldn't publish releases)\n    # and only if: manually triggered, or smoke tests passed on schedule\n    if: >\n      github.repository == 'thinking-machines-lab/tinker-cookbook' &&\n      (github.event_name == 'workflow_dispatch' ||\n       (github.event.workflow_run.conclusion == 'success' &&\n        github.event.workflow_run.event == 'schedule'))\n    steps:\n      - name: checkout\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n\n      - name: install-uv\n        uses: astral-sh/setup-uv@v6\n        with:\n          enable-cache: true\n\n      - name: build\n        run: uv build\n\n      - name: smoke test\n        run: |\n          uv run python -c \"import tinker_cookbook; print(f'Version: {tinker_cookbook.__version__}')\"\n\n      - name: get version\n        id: version\n        run: |\n          VERSION=$(uv run python -c \"import tinker_cookbook; print(tinker_cookbook.__version__)\")\n          echo \"version=$VERSION\" >> \"$GITHUB_OUTPUT\"\n\n      - name: upload artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: tinker-cookbook-nightly\n          path: dist/\n          retention-days: 7\n\n      - name: delete existing nightly release\n        run: gh release delete nightly --yes --cleanup-tag || true\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: create nightly release\n        run: |\n          VERSION=\"${{ steps.version.outputs.version }}\"\n          REPO=\"${{ github.repository }}\"\n          SHORT_SHA=\"${GITHUB_SHA::8}\"\n          {\n            echo \"Automated nightly build from \\`main\\` at $(date -u '+%Y-%m-%d %H:%M UTC').\"\n            echo \"\"\n            echo \"**Version:** \\`${VERSION}\\`\"\n            echo \"**Commit:** [\\`${SHORT_SHA}\\`](https://github.com/${REPO}/commit/${GITHUB_SHA})\"\n            echo \"\"\n            echo \"### Install\"\n            echo \"\\`\\`\\`bash\"\n            echo \"pip install 'tinker_cookbook @ https://github.com/${REPO}/releases/download/nightly/tinker_cookbook-${VERSION}-py3-none-any.whl'\"\n            echo \"\\`\\`\\`\"\n          } > /tmp/release-notes.md\n          gh release create nightly dist/* \\\n            --prerelease \\\n            --title \"Nightly Build (${VERSION})\" \\\n            --notes-file /tmp/release-notes.md\n        env:\n          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/pre-commit.yaml",
    "content": "name: pre-commit\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n\njobs:\n  pre-commit:\n    runs-on: ubuntu-latest\n\n    steps:\n    - name: Checkout code\n      uses: actions/checkout@v4\n\n    - name: pre-commit\n      uses: pre-commit/action@v3.0.1\n"
  },
  {
    "path": ".github/workflows/publish-pypi.yaml",
    "content": "name: publish-pypi\n\non:\n  push:\n    tags: [\"v[0-9]+.[0-9]+.[0-9]+\"]  # only semver tags like v1.2.3\n  workflow_dispatch:\n    inputs:\n      tag:\n        description: \"Git tag to publish (e.g. v0.2.0). Must already exist.\"\n        required: true\n\njobs:\n  publish:\n    runs-on: ubuntu-latest\n    if: github.repository == 'thinking-machines-lab/tinker-cookbook'\n\n    steps:\n      - name: determine ref\n        id: ref\n        run: |\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            echo \"ref=${{ inputs.tag }}\" >> \"$GITHUB_OUTPUT\"\n          else\n            echo \"ref=${{ github.ref }}\" >> \"$GITHUB_OUTPUT\"\n          fi\n\n      - name: checkout\n        uses: actions/checkout@v4\n        with:\n          ref: ${{ steps.ref.outputs.ref }}\n          fetch-depth: 0  # hatch-vcs needs full history for version\n\n      - name: install-uv\n        uses: astral-sh/setup-uv@v6\n        with:\n          enable-cache: true\n\n      - name: build\n        run: uv build\n\n      - name: verify version matches tag\n        run: |\n          BUILT_VERSION=$(ls dist/*.tar.gz | sed 's/.*tinker_cookbook-//;s/\\.tar\\.gz//')\n          TAG_VERSION=\"${GITHUB_REF_NAME#v}\"\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            TAG_VERSION=\"${{ inputs.tag }}\"\n            TAG_VERSION=\"${TAG_VERSION#v}\"\n          fi\n          echo \"Built version: $BUILT_VERSION\"\n          echo \"Tag version: $TAG_VERSION\"\n          if [ \"$BUILT_VERSION\" != \"$TAG_VERSION\" ]; then\n            echo \"ERROR: Built version ($BUILT_VERSION) does not match tag ($TAG_VERSION)\"\n            exit 1\n          fi\n\n      - name: run smoke test\n        run: |\n          uv run python -c \"import tinker_cookbook; print(f'Version: {tinker_cookbook.__version__}')\"\n\n      - name: publish\n        run: uv publish --token=\"$PYPI_TOKEN\"\n        env:\n          PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/pyright.yaml",
    "content": "name: pyright\n\non:\n  push:\n    branches: [main]\n  pull_request:\n\njobs:\n  type-check:\n    runs-on: ubuntu-latest\n\n    strategy:\n      matrix:\n        transformers-version: [\"4.57.6\", \"5.3.0\"]\n\n    name: type-check (transformers ${{ matrix.transformers-version }})\n\n    steps:\n      - name: checkout\n        uses: actions/checkout@v4\n\n      - name: install-uv\n        uses: astral-sh/setup-uv@v6\n        with:\n          enable-cache: true\n\n      - name: venv\n        run: uv venv && uv sync --all-extras\n\n      - name: pin transformers\n        run: uv pip install transformers==${{ matrix.transformers-version }}\n\n      - name: pyright\n        run: uv run pyright tinker_cookbook\n"
  },
  {
    "path": ".github/workflows/pytest.yaml",
    "content": "name: pytest\n\non:\n  workflow_dispatch:\n  push:\n    branches: [main]\n  pull_request:\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n\n    strategy:\n      matrix:\n        transformers-version: [\"4.57.6\", \"5.3.0\"]\n\n    name: test (transformers ${{ matrix.transformers-version }})\n\n    steps:\n      - name: checkout\n        uses: actions/checkout@v4\n\n      - name: install-uv\n        uses: astral-sh/setup-uv@v6\n        with:\n          enable-cache: true\n\n      - name: venv\n        run: uv venv && uv sync --all-extras\n\n      - name: pin transformers\n        run: uv pip install transformers==${{ matrix.transformers-version }}\n\n      - name: pytest (unit)\n        run: uv run pytest tinker_cookbook/\n        env:\n          HF_TOKEN: ${{ secrets.HF_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/smoke-test-evals.yaml",
    "content": "name: smoke-test-evals\n\non:\n  workflow_dispatch:  # manual trigger\n  schedule:\n    - cron: \"0 7 * * *\"  # daily at 7am UTC (1h after recipes)\n\n# Only one eval smoke test run at a time to avoid API contention\nconcurrency:\n  group: smoke-test-evals\n  cancel-in-progress: true\n\njobs:\n  smoke-test:\n    if: github.repository == 'thinking-machines-lab/tinker-cookbook'\n    runs-on: ubuntu-latest\n    timeout-minutes: 10\n\n    steps:\n      - name: checkout\n        uses: actions/checkout@v4\n\n      - name: install-uv\n        uses: astral-sh/setup-uv@v6\n        with:\n          enable-cache: true\n\n      - name: venv\n        run: uv venv && uv sync --all-extras\n\n      - name: run eval smoke tests\n        env:\n          TINKER_API_KEY: ${{ secrets.TINKER_API_KEY }}\n        run: uv run pytest tests/test_inspect_eval.py -v -x -s\n"
  },
  {
    "path": ".github/workflows/smoke-test-recipes.yaml",
    "content": "name: smoke-test-recipes\n\non:\n  workflow_dispatch:  # manual trigger\n  schedule:\n    - cron: \"0 6 * * *\"  # daily at 6am UTC\n\n# Only one smoke test run at a time to avoid API contention\nconcurrency:\n  group: smoke-test-recipes\n  cancel-in-progress: true\n\njobs:\n  # Discover all smoke test files so the matrix is auto-generated.\n  # Adding a new test file in tests/ automatically adds a CI job.\n  discover:\n    if: github.repository == 'thinking-machines-lab/tinker-cookbook'\n    runs-on: ubuntu-latest\n    outputs:\n      tests: ${{ steps.find.outputs.tests }}\n    steps:\n      - name: checkout\n        uses: actions/checkout@v4\n\n      - name: find smoke tests\n        id: find\n        run: |\n          tests=$(find tests/recipes -maxdepth 1 -name 'test_*.py' -printf '%f\\n' \\\n            | sed 's/\\.py$//' \\\n            | jq -R -s -c 'split(\"\\n\") | map(select(length > 0))')\n          echo \"tests=$tests\" >> \"$GITHUB_OUTPUT\"\n\n  smoke-test:\n    if: github.repository == 'thinking-machines-lab/tinker-cookbook'\n    needs: discover\n    runs-on: ubuntu-latest\n    timeout-minutes: 35\n\n    strategy:\n      fail-fast: false\n      matrix:\n        test: ${{ fromJson(needs.discover.outputs.tests) }}\n\n    name: ${{ matrix.test }}\n\n    steps:\n      - name: checkout\n        uses: actions/checkout@v4\n\n      - name: install-uv\n        uses: astral-sh/setup-uv@v6\n        with:\n          enable-cache: true\n\n      - name: venv\n        run: uv venv && uv sync --all-extras\n\n      - name: run smoke test\n        env:\n          TINKER_API_KEY: ${{ secrets.TINKER_API_KEY }}\n          HF_TOKEN: ${{ secrets.HF_TOKEN }}\n        run: uv run pytest tests/recipes/${{ matrix.test }}.py -v -x -s\n"
  },
  {
    "path": ".gitignore",
    "content": "**/__pycache__\ntinker_cookbook/_version.py\n.DS_Store\n.env\n.env.*\n.venv\nuv.lock\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "default_install_hook_types: [pre-commit, pre-push]\n\nrepos:\n- repo: https://github.com/pre-commit/pre-commit-hooks\n  rev: v5.0.0\n  hooks:\n    - id: check-added-large-files\n      args: [\"--maxkb=500\"]\n    - id: end-of-file-fixer\n      exclude: |\n          (?x)\n          ^(\n            \\.sync_state\n          )$\n    - id: trailing-whitespace\n\n- repo: https://github.com/astral-sh/ruff-pre-commit\n  rev: v0.13.2\n  hooks:\n    # Run the linter.\n    - id: ruff-check\n      exclude: tool_declaration_ts\\.py$\n    # Run the formatter.\n    - id: ruff-format\n      exclude: tool_declaration_ts\\.py$\n"
  },
  {
    "path": ".sync_state",
    "content": "{\n  \"last_synced_sha\": \"b4fee215e812ae5a6b0096ba37b3d9edc4f99cd5\",\n  \"last_sync_time\": \"2025-10-09T00:09:30.116486\"\n}\n"
  },
  {
    "path": "AGENTS.md",
    "content": "# Tinker Cookbook Agent Guide\n\nQuick reference for agents working on `tinker-cookbook`. Full documentation is in `docs/`.\n\n`tinker-cookbook` is a client library with training and eval code built on the Tinker service (hosted by Thinking Machines Lab) and the Tinker SDK (a separate repo with just the API). You author training/eval loops that run on a CPU machine; Tinker executes the heavy GPU work.\n\n**Start here:** `docs/training-sampling.mdx` - Complete walkthrough of training and sampling basics.\n\n## Documentation Map (`docs/`)\n\n**API Fundamentals:**\n- `index.mdx` - Tinker overview, division of responsibilities\n- `install.mdx` - Installation, API key setup\n- `training-sampling.mdx` - **Starter guide**: data prep, forward_backward, sampling, vision inputs\n- `losses.mdx` - Loss functions (cross_entropy, importance_sampling, ppo, cispo, dro, forward_backward_custom)\n- `save-load.mdx` - Checkpointing (save_weights_for_sampler vs save_state)\n- `async.mdx` - Sync/async APIs, futures, overlapping requests\n- `model-lineup.mdx` - Available models\n- `under-the-hood.mdx` - Clock cycles, worker pools\n\n**API Reference (`api-reference/`):**\n- `types.md` - **All API types** (Datum, ModelInput, TensorData, SamplingParams, etc.)\n- `trainingclient.md`, `samplingclient.md`, `serviceclient.md`, `restclient.md` - Client APIs\n\n**Supervised Learning (`supervised-learning/`):**\n- `../supervised-learning.mdx` - SL overview\n- `sl-basic.mdx` - First SL run\n- `sl-hyperparams.mdx` - LR formula, batch size\n- `sl-loop.mdx` - Minimal training loop\n- `prompt-distillation.mdx` - Distilling prompts\n- `sweep-case-study.mdx` - Hyperparameter sweeps\n\n**Reinforcement Learning (`rl/`):**\n- `../rl.mdx` - RL overview (RLVR, RLHF)\n- `rl-basic.mdx` - First RL run\n- `rl-envs.mdx` - Custom Env, EnvGroupBuilder, RLDataset\n- `rl-loops.mdx` - Minimal RL loop\n- `rl-hyperparams.mdx` - batch_size vs group_size, async training\n- `sequence-extension.mdx` - Multi-turn RL, KV-cache\n\n**Preferences (`preferences/`):**\n- `../preferences.mdx` - DPO vs RLHF overview\n- `dpo-guide.mdx` - DPO training\n- `rlhf-example.mdx` - RLHF pipeline\n\n**Other:**\n- `rendering.mdx` - Renderers (bridge between chat-style data and token sequences), vision inputs, TrainOnWhat\n- `completers.mdx` - TokenCompleter vs MessageCompleter\n- `evals.mdx` - Inline evals, Inspect AI, custom evaluators\n- `lora-primer.mdx` - LoRA background\n- `download-weights.mdx` / `publish-weights.mdx` - Weight export\n\n---\n\n## Composing Types\n\nAgents often struggle with the nested type hierarchy. 
Key resources:\n\n**Reference:** `docs/api-reference/types.md` documents all API types.\n\n**Core types:**\n- `Datum` = `model_input` (ModelInput) + `loss_fn_inputs` (dict of TensorData)\n- `ModelInput` = list of chunks (EncodedTextChunk, ImageChunk)\n- `TensorData` = wrapper for numpy/torch arrays with shape info\n\n**Helper functions** (use these instead of manual construction):\n- `datum_from_model_input_weights(model_input, weights, max_length)` - SL datum creation (`supervised/common.py`)\n- `conversation_to_datum(messages, renderer, max_length, train_on_what)` - Full pipeline (`supervised/data.py`)\n- `renderer.build_supervised_example(messages)` - Returns (ModelInput, weights)\n- `ModelInput.from_ints(tokens)` - Create from token list\n- `TensorData.from_numpy(arr)` / `TensorData.from_torch(tensor)` - Wrap arrays\n\n---\n\n## Architecture\n\n**Builder pattern:** Config objects are `chz` dataclasses (SupervisedDatasetBuilder, RLDatasetBuilder, EnvGroupBuilder). They expose `.build()`/`__call__()` returning runtime objects.\n\n**Key code locations:**\n- SL: `tinker_cookbook/supervised/train.py`\n- RL: `tinker_cookbook/rl/train.py`\n- DPO: `tinker_cookbook/preference/train_dpo.py`\n- Renderers: `tinker_cookbook/renderers/`\n- Completers: `tinker_cookbook/completers.py`\n- RL types: `tinker_cookbook/rl/types.py`\n- Rollout strategies: `tinker_cookbook/rl/rollout_strategy.py` (FailFast, RetryOnFailure)\n- Logging: `tinker_cookbook/utils/logtree.py`, `tinker_cookbook/rl/rollouts.py`\n- Recipes: `tinker_cookbook/recipes/`\n\n**Training outputs:** RL and SL training write human-readable HTML reports and machine-readable JSON files (metrics, rollout transcripts, per-trajectory summaries) to `log_path`. Point agents at a `log_path` directory to analyze training runs — `metrics.jsonl` for scalar metrics, `*_rollout_summaries.jsonl` for per-trajectory data, and `*_logtree.json` for full rollout transcripts including model responses. See `docs/rl/rl-logging.mdx` for the complete file reference and parsing examples.\n\n---\n\n## Conventions\n\n**Subscript suffixes** for tensor names: `_P` (problems), `_G` (groups), `_T` (tokens), `_D` (datums). Example: `tokens_P_G_T[p][g][t]`\n\n**Code style:**\n- Explicit typing; avoid `Any` / `type: ignore`\n- Use `safezip`, `timed`, `scope` helpers\n- `@chz.chz` decorator for config serialization\n- `ml_log.log_metrics` for metrics; `logtree` for transcripts\n\n**Env lifecycle:** `Env` objects are single-use (no reset). Create via `EnvGroupBuilder`.\n\n---\n\n## Common Pitfalls\n\n1. **LoRA LR:** Use `hyperparam_utils.get_lr(model_name)` - LoRA needs ~10x higher LR than full fine-tuning.\n\n2. **Renderer mismatch:** Match `renderer_name` to model family (`llama3`, `qwen3`, `role_colon`).\n\n3. **Async gaps:** Submit `forward_backward_async` and `optim_step_async` back-to-back before awaiting.\n\n4. **Sampler desync:** Create a **new** sampling client after saving weights.\n\n5. **Type construction:** Use helper functions, not manual dict construction. See `supervised/data.py` and `supervised/common.py`.\n\n6. **Group semantics:** RL advantages are centered within each group.\n\n7. **DPO:** Start with `dpo_beta=0.1`, LR~1e-5.\n\n---\n\n## Testing\n\n```bash\n# Unit tests (no API needed, colocated *_test.py files)\npytest tinker_cookbook/\n\n# Smoke tests (requires TINKER_API_KEY + network)\npytest tests/\n```\n\nFor debugging, shrink workloads via `n_batches`, `batch_size`, `group_size` in dataset builders.\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\n\nA curated feed of notable changes to `tinker-cookbook`. Small bugfixes and minor argument additions are omitted—this is for changes worth knowing about.\n\n## Format\n\nEach entry includes:\n- **Title**: A short, human-readable summary (not the commit message)\n- **Date**: When it was merged\n- **Type**: `new` (feature), `improvement` (enhancement to existing functionality), or `fix`\n- **Tags**: What area it touches (e.g., `renderers`, `rl`, `supervised`, `eval`, `datasets`)\n- **PR**: Link to the pull request\n\n---\n\n### [cookbook] Cap training steps with `max_step` parameter ([#328](https://github.com/thinking-machines-lab/tinker-cookbook/pull/328))\n**Date:** 2026-01-28\n**Type:** new\n**Tags:** rl, supervised\n\nAdds optional `max_step` config parameter to cap training steps in on-policy distillation. When set, trains for `min(max_step, dataset_length)`. Default `None` preserves existing behavior.\n\n---\n\n### [cookbook] Configurable KL penalty reference model ([#326](https://github.com/thinking-machines-lab/tinker-cookbook/pull/326))\n**Date:** 2026-01-27\n**Type:** new\n**Tags:** rl\n\nMakes the KL penalty reference model configurable in RL training. Users can now specify a different base model or a checkpoint for the KL penalty computation, rather than using the default.\n\n---\n\n### [cookbook] Checkpoints now have 7-day TTL by default ([#324](https://github.com/thinking-machines-lab/tinker-cookbook/pull/324))\n**Date:** 2026-01-27\n**Type:** improvement\n**Tags:** infrastructure\n\nCheckpoints are now set to auto-expire after 7 days by default, helping users avoid unexpected storage costs.\n\n---\n\n### [cookbook] Support for dedicated capacity ([#315](https://github.com/thinking-machines-lab/tinker-cookbook/pull/315))\n**Date:** 2026-01-21\n**Type:** new\n**Tags:** infrastructure\n\nAdds support for dedicated capacity in training configurations.\n\n---\n\n### [cookbook] Modal sandbox backend for code execution ([#278](https://github.com/thinking-machines-lab/tinker-cookbook/pull/278), [#291](https://github.com/thinking-machines-lab/tinker-cookbook/pull/291), [#300](https://github.com/thinking-machines-lab/tinker-cookbook/pull/300), [#302](https://github.com/thinking-machines-lab/tinker-cookbook/pull/302))\n**Date:** 2026-01-07 to 2026-01-15\n**Type:** new\n**Tags:** sandboxes, rl\n\nAdds Modal as an alternative sandbox backend for code execution alongside SandboxFusion. Includes:\n- `ModalSandbox` and `ModalSandboxPool` for managing sandboxes\n- Warm pool maintenance with configurable timeouts\n- Rate limiting to respect Modal account limits\n- Async API calls for better performance\n- Documentation for both sandbox backends\n\nSee `tinker_cookbook/sandbox/` for the new module structure.\n\n---\n\n### [cookbook] Fix streaming dataset batch skipping ([#295](https://github.com/thinking-machines-lab/tinker-cookbook/pull/295))\n**Date:** 2026-01-19\n**Type:** fix\n**Tags:** supervised\n\nHuggingFace's shuffle is deterministic, so batch skipping now works correctly with streaming datasets. Forward skipping through batches no longer causes data inconsistencies.\n\n---\n\n### [cookbook] Fix supervised metrics from OptimStepResponse ([#286](https://github.com/thinking-machines-lab/tinker-cookbook/pull/286))\n**Date:** 2026-01-20\n**Type:** fix\n**Tags:** supervised\n\nPreviously, optimization metrics (like gradient norms) from `OptimStepResponse` were being dropped in `finish_batch`. 
Metrics are now properly captured and merged into the step's metrics dictionary.\n\n---\n\n### [cookbook] Adapter to base-model merge script ([#292](https://github.com/thinking-machines-lab/tinker-cookbook/pull/292))\n**Date:** 2026-01-08\n**Type:** new\n**Tags:** tools\n\nNew script to merge LoRA/adapter weights back into the base model.\n\n---\n\n### [cookbook] Fix inspect_utils for list content from parse_response ([#299](https://github.com/thinking-machines-lab/tinker-cookbook/pull/299))\n**Date:** 2026-01-12\n**Type:** fix\n**Tags:** eval\n\nFixed `inspect_utils.py` which assumed `parse_response` always returns string content. Renderers like `Qwen3Renderer` return list content (with `ThinkingPart`, `ToolCallPart`, etc.) when responses contain `<think>` or `<tool_call>` blocks. Now uses `renderers.get_text_content()` which handles both formats.\n\n---\n\n### [cookbook] Fix Kimi K2 and DeepSeek V3 renderer parsing ([#279](https://github.com/thinking-machines-lab/tinker-cookbook/pull/279), [#285](https://github.com/thinking-machines-lab/tinker-cookbook/pull/285))\n**Date:** 2026-01-05 to 2026-01-07\n**Type:** fix\n**Tags:** renderers\n\nFixes tool declaration rendering for Kimi K2 and Qwen3 to match HuggingFace templates. Also fixes DeepSeekV3ThinkingRenderer to properly parse thinking traces via a round-trip test ensuring `build_supervised_example` and `parse_response` correspondence.\n\n---\n\n### [sdk] Torch is now an optional dependency ([#15](https://github.com/thinking-machines-lab/tinker/pull/15))\n**Date:** 2026-01-20\n**Type:** improvement\n**Tags:** dependencies\n\nMoves torch to an optional dependency in the SDK. Applications that don't need torch for training can now use the SDK without installing it. Import guards added to `training_client.py`.\n\n---\n\n### Major renderer overhaul: tool calling, structured content ([#220](https://github.com/thinking-machines-lab/tinker-cookbook/pull/220), [#221](https://github.com/thinking-machines-lab/tinker-cookbook/pull/221), [#238](https://github.com/thinking-machines-lab/tinker-cookbook/pull/238), [#243](https://github.com/thinking-machines-lab/tinker-cookbook/pull/243), [#244](https://github.com/thinking-machines-lab/tinker-cookbook/pull/244), [#250](https://github.com/thinking-machines-lab/tinker-cookbook/pull/250))\n**Date:** 2025-12-26 to 2025-12-28\n**Type:** improvement\n**Tags:** renderers, rl\n\nA series of PRs that significantly improve the renderer system:\n\n**Tool calling support:** New `ToolSpec` type for defining tools and `create_conversation_prefix_with_tools()` API on all renderers. Tool call parsing supported for Qwen3, DeepSeek V3, and Kimi K2. `UnparsedToolCall` captures tool calls that fail to parse.\n\n**Structured message content:** The `Message.thinking` field is removed (**breaking**). Thinking content is now represented as `ThinkingPart` in the content list, alongside `TextPart`, `ImagePart`, and `ToolCallPart`. Use `get_text_content(message)` to extract text after `parse_response`.\n\n**Clearer field names:** `RenderedMessage` fields renamed (**breaking**): `prefix` → `header`, `content` → `output`, `suffix` → `stop_overlap`. 
`Renderer` changed from Protocol to ABC.\n\n**Sequence extension property:** New `has_extension_property` on `Renderer` indicates whether consecutive timesteps can be merged for O(T) instead of O(T²) compute in multi-turn RL.\n\n**Modular architecture:** `renderers.py` split into `tinker_cookbook/renderers/` package with per-model modules (`qwen3.py`, `deepseek_v3.py`, `kimi_k2.py`, etc.). Imports unchanged.\n\n**HF compatibility:** Various fixes to match HuggingFace chat templates, with expanded test coverage using random conversation generation.\n\n---\n\n### Qwen3 thinking blocks can now be preserved in history ([#142](https://github.com/thinking-machines-lab/tinker-cookbook/pull/142))\n**Date:** 2025-12-06\n**Type:** new\n**Tags:** renderers, rl\n\nThe Qwen3Renderer now has a `strip_thinking_from_history` option. By default (`True`), it strips `<think>...</think>` blocks from previous assistant turns—matching how Qwen3 was trained. Set it to `False` if you're doing multi-turn RL and want to use sequence extension: preserving thinking lets turns merge into one sequence, reducing compute cost.\n\n---\n\n### Disable checkpoint saving with `save_every=0` ([#149](https://github.com/thinking-machines-lab/tinker-cookbook/pull/149))\n**Date:** 2025-12-06\n**Type:** improvement\n**Tags:** supervised, rl\n\nSetting `save_every=0` now disables checkpoint saving entirely (previously it crashed with a divide-by-zero). Useful for quick test runs where you don't need checkpoints.\n\n---\n\n### xmux: launch experiment sweeps in tmux ([#138](https://github.com/thinking-machines-lab/tinker-cookbook/pull/138))\n**Date:** 2025-12-02\n**Type:** new\n**Tags:** tools\n\nNew `xmux` utility for running experiment sweeps. It spawns parallel jobs in a tmux session where you can monitor each experiment's progress in separate windows. See `tinker_cookbook/xmux/examples/` for usage.\n\n---\n\n### Optimizer state now loads correctly on resume ([#140](https://github.com/thinking-machines-lab/tinker-cookbook/pull/140), [#141](https://github.com/thinking-machines-lab/tinker-cookbook/pull/141))\n**Date:** 2025-12-02\n**Type:** fix\n**Tags:** supervised, rl\n\nTraining resumption now properly loads optimizer state (momentum, etc.) alongside model weights. Previously, `load_state()` didn't restore the optimizer, which could affect training dynamics after a checkpoint resume.\n\n---\n\n### Tracing support for supervised training ([#88](https://github.com/thinking-machines-lab/tinker-cookbook/pull/88))\n**Date:** 2025-11-21\n**Type:** new\n**Tags:** supervised, tools\n\nSet `enable_trace=True` to generate trace events during supervised training. Visualize with Perfetto to see where time is spent. Run `python -m tinker_cookbook.utils.trace` to convert the trace file.\n\n---\n\n### Code RL recipe with DeepCoder ([#83](https://github.com/thinking-machines-lab/tinker-cookbook/pull/83))\n**Date:** 2025-11-18\n**Type:** new\n**Tags:** recipes, rl\n\nNew recipe for RL on competitive programming problems using the DeepCoder dataset. Code execution is sandboxed via Sandbox Fusion. See `tinker_cookbook/recipes/code_rl/`.\n\n---\n\n### Configurable temperature for RL sampling ([#86](https://github.com/thinking-machines-lab/tinker-cookbook/pull/86))\n**Date:** 2025-11-17\n**Type:** new\n**Tags:** rl\n\nTemperature is now a configurable parameter in RL configs. 
Previously hardcoded to 1.0.\n\n---\n\n### Per-message training control with `TrainOnWhat.CUSTOMIZED` ([#85](https://github.com/thinking-machines-lab/tinker-cookbook/pull/85))\n**Date:** 2025-11-14\n**Type:** new\n**Tags:** supervised, renderers\n\nNew `TrainOnWhat.CUSTOMIZED` option lets you set a `trainable: bool` field on each message to control which messages get loss applied. Useful for training on specific turns in a conversation.\n\n---\n\n### Interactive environment debugging with `play_w_env` ([#76](https://github.com/thinking-machines-lab/tinker-cookbook/pull/76))\n**Date:** 2025-11-07\n**Type:** new\n**Tags:** rl, tools\n\nNew utility to \"role-play\" as the policy and interact with an Environment. Useful for debugging reward functions and environment logic. See `tinker_cookbook/recipes/multiplayer_rl/twenty_questions/play.py` for an example.\n\n---\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to Tinker Cookbook\n\nWe welcome contributions! This project is built in the spirit of open science and collaborative development.\n\n## Development setup\n\n```bash\ngit clone https://github.com/thinking-machines-lab/tinker-cookbook.git\ncd tinker-cookbook\nuv sync --extra dev\npre-commit install\n```\n\nThis installs dev dependencies and registers pre-commit hooks that run `ruff` formatting and linting on every commit.\n\n## Running tests\n\n```bash\n# Unit tests (no API key needed, colocated *_test.py files)\nuv run pytest tinker_cookbook/\n\n# Integration tests (requires TINKER_API_KEY)\nuv run pytest tests/\n```\n\n## Code style\n\nWe use [ruff](https://docs.astral.sh/ruff/) for linting and formatting (line length: 100). Pre-commit hooks run automatically on each commit.\n\n```bash\nuv run ruff check tinker_cookbook/\nuv run ruff format tinker_cookbook/\n```\n\n## Type checking\n\nWe use [pyright](https://github.com/microsoft/pyright) for static type analysis. Please use typing wherever possible; avoid `Any` and `type: ignore`; prefer casting. However, avoid convoluted generics or overly verbose code just to satisfy the type checker. Prefer single types over union types.\n\n```bash\nuv run pyright tinker_cookbook\n```\n\n## Pull request process\n\n1. Create a feature branch from `main`\n2. Make your changes with tests if applicable\n3. Ensure all checks pass: `pre-commit run --all-files`\n4. Open a PR with a clear description of the change\n\nCI runs pre-commit, pyright, and pytest on every PR.\n\n## Project structure\n\n- `tinker_cookbook/` — Library code (supervised learning, RL, renderers, utilities)\n- `tinker_cookbook/recipes/` — Example training scripts\n- `tests/` — Integration tests (require API key)\n- `docs/` — Documentation (MDX format, synced to docs site)\n\n---\n\n# Design conventions\n\n## Organization of training scripts\n\nWe're designing the codebase with the following goals:\n\n1. Low barrier to entry: it should be dead simple to run something and see numbers go up.\n2. Extensible: it should be possible to pass in custom datasets and evals and control all the hyperparameters.\n3. Science-friendly: it should be easy to run sweeps, and analyze the results.\n\nTo achieve this, we'll use the following structure around training scripts:\n\n- There's a main training function, such as [rl/train.py](tinker_cookbook/rl/train.py) or [supervised/train.py](tinker_cookbook/supervised/train.py), which contains the main loop.\n    - This function contains a detailed config object (`Config`), which isn't constructable from the command line.\n    - The config contains members that specify things like datasets and evals. However, these should be chz configs (with a `.build` method that constructs the actual object) or callables (we recommend using functools.partial). This way, the config is serializable, which is useful for sweeps.\n- There are launch scripts that assemble training configs (e.g., [recipes/math_rl/train.py](tinker_cookbook/recipes/math_rl/train.py)), which construct a smaller config object (`CLIConfig`) from the command line.\n\n## Async\n\nAsync is very useful for RL, where it allows us to make many queries in parallel (e.g., sampling calls). For all of the interfaces used in RL (such as the `Env` class), all the methods that take nontrivial amounts of time should be async. 
For some of the other code, such as [recipes/sl_loop.py](tinker_cookbook/recipes/sl_loop.py), we've chosen not to use async methods, just to make it more beginner-friendly, as many Python programmers are not familiar with async.\n\n## Classes\n\nThere are a lot of different classes, which might make the code feel less approachable. However, they follow *the builder pattern*, and the code should be less confusing when you know the pattern.\n\nWe can illustrate the pattern with the two main examples:\n\n- A `SupervisedDatasetBuilder` is a configuration object which builds a `SupervisedDataset`.\n- An `RLDatasetBuilder` is a configuration object which builds an `RLDataset`, which generates batches of `EnvGroupBuilder` objects, which each generate a group of `Env` objects.\n\nHere, the `SupervisedDatasetBuilder`, `RLDatasetBuilder`, and `EnvGroupBuilder` are all configuration objects, which have a `__call__` method that builds another object. You can see these objects in [supervised/types.py](tinker_cookbook/supervised/types.py) and [rl/types.py](tinker_cookbook/rl/types.py).\n\nIn general, we use a lot of configuration objects, with a `__call__` method that returns a heavyweight object (like a dataset). We use `chz` for the configuration objects -- it's similar to a dataclass but with some extra features that are nice for configs. We use either dataclasses or regular Python classes for the heavyweight objects.\n\n## Envs\n\nAn `Env` is an RL environment. For those with an RL background, it roughly corresponds to an MDP or a POMDP; however, we use it in more general cases (such as multi-agent settings) that don't strictly correspond to the MDP/POMDP formalism. It's roughly analogous to the concept of an Env in OpenAI Gym, but unlike OpenAI Gym, we don't have a `reset` method; rather, the env should be discarded after a rollout. Any shared resources should be maintained by whatever object is creating the envs.\n\nThe `Env`s are created by `EnvGroupBuilder`s. The group of envs returned by `EnvGroupBuilder` has something in common: either they correspond to the same task (in which case we can use this information for variance reduction, as in GRPO, which centers per group), or we can use the group to define a multi-agent environment.\n\n- One common multi-agent environment is where we use a pairwise preference model to compare pairs of completions.\n- We can also use the group to define a two-player game. Some two-player games such as tic-tac-toe are currently supported through the [text_arena](tinker_cookbook/recipes/multiplayer_rl/text_arena/env.py) environments.\n\n## Notation\n\nWe'll use subscripts to indicate the shapes of objects. For example, `tokens_P_G_T` indicates a three-dimensional array of tokens, with `P` problems, `G` groups, and `T` tokens per group, so `tokens_P_G_T[p][g][t]` should refer to a single token. In many cases, the arrays will be ragged. E.g., the `T` axis will have different lengths for different `(p,g)`. Sometimes, a given dimension will be flattened from two dimensions. 
If we write `tokens_PG_T`, that means that we have a two dimensional array, where the 0th dimension is flattened from the `P` and `G` dimensions.\n\n### Common Dimension Names\n\nHere are the standard dimension subscripts used throughout the codebase:\n\n- `_D`: Data/Datum dimension (for training data items)\n- `_G`: Group dimension (for multiple attempts/rollouts of the same problem)\n- `_P`: Problem dimension (for different problems/prompts)\n- `_T`: Token/Time dimension (for sequences)\n\nThe relationship between dimensions in RL:\n- A batch contains multiple problems (`_P`)\n- Each problem spawns multiple attempts/environments (`_G`), forming a group\n- Each attempt produces one trajectory\n- Advantages are normalized within each group (across the `_G` dimension)\n\nExamples:\n- `env_group_builders_P`: A list of environment builders, one per problem\n- `trajectories_G`: Multiple trajectories from attempts at the same problem\n- `rewards_G`: Rewards for each attempt within a group\n- `tokens_P_G_T`: Tokens with problem, group, and time dimensions\n- `data_D`: A list of training data items\n\n## Questions?\n\nEmail us at tinker@thinkingmachines.ai.\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2025 Thinking Machines Lab\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<h1 align=\"center\">Tinker Cookbook</h1>\n<div align=\"center\">\n  <img src=\"assets/tinker-cover.png\" width=\"60%\" />\n</div>\n\n<div align=\"center\">\n\n[![pytest](https://github.com/thinking-machines-lab/tinker-cookbook/actions/workflows/pytest.yaml/badge.svg)](https://github.com/thinking-machines-lab/tinker-cookbook/actions/workflows/pytest.yaml)\n[![pyright](https://github.com/thinking-machines-lab/tinker-cookbook/actions/workflows/pyright.yaml/badge.svg)](https://github.com/thinking-machines-lab/tinker-cookbook/actions/workflows/pyright.yaml)\n[![smoke-test-recipes](https://github.com/thinking-machines-lab/tinker-cookbook/actions/workflows/smoke-test-recipes.yaml/badge.svg)](https://github.com/thinking-machines-lab/tinker-cookbook/actions/workflows/smoke-test-recipes.yaml)\n[![PyPI](https://img.shields.io/pypi/v/tinker-cookbook)](https://pypi.org/project/tinker-cookbook/)\n\n</div>\n\nWe provide two libraries for the broader community to customize their language models: `tinker` and `tinker-cookbook`.\n\n- `tinker` is a training SDK for researchers and developers to fine-tune language models. You send API requests to us and we handle the complexities of distributed training.\n- `tinker-cookbook` includes realistic examples of fine-tuning language models. It builds on the Tinker API and provides common abstractions to fine-tune language models.\n\n## Installation\n\n1. Sign up for Tinker [here](https://auth.thinkingmachines.ai/sign-up).\n2. Once you have access, create an API key from the [console](https://tinker-console.thinkingmachines.ai) and export it as environment variable `TINKER_API_KEY`.\n3. Install `tinker-cookbook` (includes the `tinker` SDK as a dependency):\n   ```bash\n   # Latest stable release from PyPI\n   uv pip install tinker-cookbook\n\n   # Or install the nightly build\n   uv pip install 'tinker-cookbook @ git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly'\n   ```\n\n## Tinker\n\nRefer to the [docs](https://tinker-docs.thinkingmachines.ai/training-sampling) to start from basics.\nHere we introduce a few Tinker primitives - the basic components to fine-tune LLMs:\n\n```python\nimport tinker\nservice_client = tinker.ServiceClient()\ntraining_client = service_client.create_lora_training_client(\n  base_model=\"meta-llama/Llama-3.2-1B\", rank=32,\n)\ntraining_client.forward_backward(...)\ntraining_client.optim_step(...)\ntraining_client.save_state(...)\ntraining_client.load_state(...)\n\nsampling_client = training_client.save_weights_and_get_sampling_client(name=\"my_model\")\nsampling_client.sample(...)\n```\n\nSee [tinker_cookbook/recipes/sl_loop.py](tinker_cookbook/recipes/sl_loop.py) and [tinker_cookbook/recipes/rl_loop.py](tinker_cookbook/recipes/rl_loop.py) for minimal examples of using these primitives to fine-tune LLMs.\n\nTo download the weights of any model:\n```python\nrest_client = service_client.create_rest_client()\nfuture = rest_client.get_checkpoint_archive_url_from_tinker_path(sampling_client.model_path)\nwith open(f\"model-checkpoint.tar.gz\", \"wb\") as f:\n    f.write(future.result())\n```\n\n### Tinker Cookbook\n\nBesides these primitives, we also offer **Tinker Cookbook** (a.k.a. 
this repo), a library of a wide range of abstractions to help you customize training environments.\n[`tinker_cookbook/recipes/sl_basic.py`](tinker_cookbook/recipes/sl_basic.py) and [`tinker_cookbook/recipes/rl_basic.py`](tinker_cookbook/recipes/rl_basic.py) contain minimal examples to configure supervised learning and reinforcement learning.\n\nWe also include a wide range of more sophisticated examples in the [`tinker_cookbook/recipes/`](tinker_cookbook/recipes/) folder:\n1. **[Chat supervised learning](tinker_cookbook/recipes/chat_sl/)**: supervised fine-tuning on conversational datasets like Tulu3.\n2. **[Math reasoning](tinker_cookbook/recipes/math_rl/)**: improve LLM reasoning capability by rewarding it for answering math questions correctly.\n3. **[Preference learning](tinker_cookbook/recipes/preference/)**: showcase a three-stage RLHF pipeline: 1) supervised fine-tuning, 2) learning a reward model, 3) RL against the reward model.\n4. **[Tool use](tinker_cookbook/recipes/search_tool/)**: train LLMs to better use retrieval tools to answer questions more accurately.\n5. **[Prompt distillation](tinker_cookbook/recipes/prompt_distillation/)**: internalize long and complex instructions into LLMs.\n6. **[Multi-Agent](tinker_cookbook/recipes/multiplayer_rl/)**: optimize LLMs to play against another LLM or themselves.\n\nThese examples are located in each subfolder, and their `README.md` files will walk you through the key implementation details, the commands to run them, and the expected performance.\n\n### Documentation\n\nThe `docs/` directory contains a mirror of the Tinker documentation. These files are synced from our internal documentation site.\n\n**Note:** The documentation files use MDX format (Markdown with JSX), which includes some syntax that isn't standard Markdown. You may see things like `import` statements, `<Callout>` components, or curly-brace expressions. These are artifacts of our documentation framework - the actual content should still be readable as Markdown.\n\nIf you find errors or want to improve the documentation, feel free to submit a PR editing files in `docs/`. We'll sync the changes back to our documentation site.\n\nFor the rendered documentation, visit [tinker-docs.thinkingmachines.ai](https://tinker-docs.thinkingmachines.ai).\n\n### Import our utilities\n\nTinker cookbook includes several utilities. Here's a quick overview:\n- [`renderers`](tinker_cookbook/renderers/) converts tokens from/to structured chat message objects\n- [`hyperparam_utils`](tinker_cookbook/hyperparam_utils.py) helps calculate hyperparameters suitable for LoRAs\n- [`evaluation`](tinker_cookbook/eval/evaluators.py) provides abstractions for evaluating Tinker models and [`inspect_evaluation`](tinker_cookbook/eval/inspect_evaluators.py) shows how to integrate with InspectAI to make evaluating on standard benchmarks easy.\n\n## Development Setup\n\n```bash\nuv sync --extra dev\npre-commit install\n```\n\nThis installs dev dependencies and registers pre-commit hooks that run `ruff` formatting and linting on every commit. CI enforces these checks on all pull requests.\n\n## Contributing\n\nThis project is built in the spirit of open science and collaborative development. We believe that the best tools emerge through community involvement and shared learning.\n\nWe welcome PR contributions after our private beta is over. 
If you have any feedback, please email us at tinker@thinkingmachines.ai.\n\n## Citation\nIf you use Tinker for your research, please cite it as:\n```\nThinking Machines Lab, 2025. Tinker. https://thinkingmachines.ai/tinker/.\n```\n\nOr use this BibTeX citation:\n```\n@misc{tml2025tinker,\n  author = {Thinking Machines Lab},\n  title = {Tinker},\n  year = {2025},\n  url = {https://thinkingmachines.ai/tinker/},\n}\n```\n"
  },
  {
    "path": "docs/api-reference/apifuture.md",
    "content": "API Future classes for handling async operations with retry logic.\n\n## `APIFuture` Objects\n\n```python\nclass APIFuture(ABC, Generic[T])\n```\n\nAbstract base class for futures that can be awaited or accessed synchronously.\n\nAPIFuture provides a unified interface for handling async operations that can\nbe accessed both synchronously (via result()) and asynchronously (via await or result_async()).\nThis allows for flexible usage patterns in both sync and async contexts.\n\nThe future can be awaited directly in async contexts:\n```python\nresult = await api_future  # Equivalent to await api_future.result_async()\n```\n\nOr accessed synchronously:\n```python\nresult = api_future.result()  # Blocks until complete\n```\n\nArgs:\n- `T`: The type of the result value\n\nExample:\n```python\n# In async context\nfuture = training_client.forward_backward(data, \"cross_entropy\")\nresult = await future  # Or await future.result_async()\n\n# In sync context\nfuture = training_client.forward_backward(data, \"cross_entropy\")\nresult = future.result()\n```\n\n#### `result_async`\n\n```python\nasync def result_async(timeout: float | None = None) -> T\n```\n\nGet the result asynchronously with optional timeout.\n\nArgs:\n- `timeout`: Maximum time to wait in seconds. None means wait indefinitely.\n\nReturns:\n- The result value of type `T`\n\nRaises:\n    TimeoutError: If timeout is exceeded\n\n#### `result`\n\n```python\ndef result(timeout: float | None = None) -> T\n```\n\nGet the result synchronously with optional timeout.\n\nArgs:\n- `timeout`: Maximum time to wait in seconds. None means wait indefinitely.\n\nReturns:\n- The result value of type `T`\n\nRaises:\n    TimeoutError: If timeout is exceeded\n\n## `AwaitableConcurrentFuture` Objects\n\n```python\nclass AwaitableConcurrentFuture(APIFuture[T])\n```\n\nImplementation of APIFuture that wraps a concurrent.futures.Future.\n\nThis class bridges Python's concurrent.futures with asyncio, allowing a\nstandard Future to be used in async contexts. It's commonly returned by\nTinker API methods to provide both sync and async access patterns.\n\nArgs:\n- `future`: A concurrent.futures.Future to wrap\n\nExample:\n```python\n# Internal usage - typically you receive these from API methods\nconcurrent_future = some_operation()\napi_future = AwaitableConcurrentFuture(concurrent_future)\n\n# Can be used synchronously\nresult = api_future.result()\n\n# Or asynchronously\nresult = await api_future\n```\n\n#### `result`\n\n```python\ndef result(timeout: float | None = None) -> T\n```\n\nGet the result synchronously with optional timeout.\n\nArgs:\n- `timeout`: Maximum time to wait in seconds. 
None means wait indefinitely.\n\nReturns:\n- The result value of type `T`\n\nRaises:\n    TimeoutError: If timeout is exceeded\n    Exception: Any exception raised by the underlying operation\n\nExample:\n```python\nfuture = rest_client.get_training_run(\"run-id\")\nresult = future.result(timeout=30)  # Wait up to 30 seconds\n```\n\n#### `result_async`\n\n```python\nasync def result_async(timeout: float | None = None) -> T\n```\n\nAsync version of result.\n\n#### `future`\n\n```python\ndef future() -> ConcurrentFuture[T]\n```\n\nGet the underlying concurrent.futures.Future.\n\nReturns:\n- The wrapped `ConcurrentFuture` object\n\nExample:\n```python\napi_future = rest_client.get_training_run(\"run-id\")\nconcurrent_future = api_future.future()\n# Can now use standard concurrent.futures methods\nif concurrent_future.done():\n    result = concurrent_future.result()\n```\n"
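The reference above notes that both `result()` and `result_async()` raise `TimeoutError` when a timeout is exceeded, but no example shows handling it. Below is a minimal sketch of a bounded synchronous wait; it assumes an existing `training_client` and prepared `data` (as in the examples above), and the retry note assumes the wrapped operation behaves like a standard `concurrent.futures.Future`, where a timeout does not cancel the request.

```python
# Minimal sketch: bounded synchronous wait on an APIFuture, assuming an
# existing `training_client` and prepared `data` (as in the examples above).
future = training_client.forward_backward(data, "cross_entropy")
try:
    result = future.result(timeout=60)  # block for at most 60 seconds
except TimeoutError:
    # On Python < 3.11 this may surface as concurrent.futures.TimeoutError.
    # A timeout does not cancel the underlying request, so result() can be
    # called again (or awaited via result_async()) later.
    print("Still waiting on the server; try again later")
```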
  },
  {
    "path": "docs/api-reference/exceptions.md",
    "content": "## `TinkerError` Objects\n\n```python\nclass TinkerError(Exception)\n```\n\nBase exception for all Tinker-related errors.\n\n## `APIError` Objects\n\n```python\nclass APIError(TinkerError)\n```\n\nBase class for all API-related errors.\n\n#### `body`\n\nThe API response body.\n\nIf the API responded with a valid JSON structure then this property will be the\ndecoded result.\n\nIf it isn't a valid JSON structure then this will be the raw response.\n\nIf there was no response associated with this error then it will be `None`.\n\n## `APIResponseValidationError` Objects\n\n```python\nclass APIResponseValidationError(APIError)\n```\n\nRaised when API response doesn't match expected schema.\n\n## `APIStatusError` Objects\n\n```python\nclass APIStatusError(APIError)\n```\n\nRaised when an API response has a status code of 4xx or 5xx.\n\n## `APIConnectionError` Objects\n\n```python\nclass APIConnectionError(APIError)\n```\n\nRaised when a connection error occurs while making an API request.\n\n## `APITimeoutError` Objects\n\n```python\nclass APITimeoutError(APIConnectionError)\n```\n\nRaised when an API request times out.\n\n## `BadRequestError` Objects\n\n```python\nclass BadRequestError(APIStatusError)\n```\n\nHTTP 400: The request was invalid or malformed.\n\n## `AuthenticationError` Objects\n\n```python\nclass AuthenticationError(APIStatusError)\n```\n\nHTTP 401: Authentication credentials are missing or invalid.\n\n## `PermissionDeniedError` Objects\n\n```python\nclass PermissionDeniedError(APIStatusError)\n```\n\nHTTP 403: Insufficient permissions to access the resource.\n\n## `NotFoundError` Objects\n\n```python\nclass NotFoundError(APIStatusError)\n```\n\nHTTP 404: The requested resource was not found.\n\n## `ConflictError` Objects\n\n```python\nclass ConflictError(APIStatusError)\n```\n\nHTTP 409: The request conflicts with the current state of the resource.\n\n## `UnprocessableEntityError` Objects\n\n```python\nclass UnprocessableEntityError(APIStatusError)\n```\n\nHTTP 422: The request was well-formed but contains semantic errors.\n\n## `RateLimitError` Objects\n\n```python\nclass RateLimitError(APIStatusError)\n```\n\nHTTP 429: Too many requests, rate limit exceeded.\n\n## `InternalServerError` Objects\n\n```python\nclass InternalServerError(APIStatusError)\n```\n\nHTTP 500+: An error occurred on the server.\n\n## `SidecarError` Objects\n\n```python\nclass SidecarError(TinkerError)\n```\n\nBase exception for subprocess sidecar errors.\n\n## `SidecarStartupError` Objects\n\n```python\nclass SidecarStartupError(SidecarError)\n```\n\nRaised when the sidecar subprocess fails to start or times out.\n\n## `SidecarDiedError` Objects\n\n```python\nclass SidecarDiedError(SidecarError)\n```\n\nRaised when the sidecar subprocess exits unexpectedly while requests are pending.\n\n## `SidecarIPCError` Objects\n\n```python\nclass SidecarIPCError(SidecarError)\n```\n\nRaised when communication with the sidecar subprocess fails.\n\n## `RequestFailedError` Objects\n\n```python\nclass RequestFailedError(TinkerError)\n```\n\nRaised when an asynchronous request completes in a failed state.\n"
  },
  {
    "path": "docs/api-reference/restclient.md",
    "content": "RestClient for Tinker API REST operations.\n\n## `RestClient` Objects\n\n```python\nclass RestClient(TelemetryProvider)\n```\n\nClient for REST API operations like listing checkpoints and metadata.\n\nThe RestClient provides access to various REST endpoints for querying\nmodel information, checkpoints, and other resources. You typically get one\nby calling `service_client.create_rest_client()`.\n\nKey methods:\n- list_checkpoints() - list available model checkpoints (both training and sampler)\n- list_user_checkpoints() - list all checkpoints across all user's training runs\n- get_training_run() - get model information and metadata as ModelEntry\n- delete_checkpoint() - delete an existing checkpoint for a training run\n- get_checkpoint_archive_url() - get signed URL to download checkpoint archive\n- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public\n- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private\n- set_checkpoint_ttl_from_tinker_path() - set or remove TTL on a checkpoint\n\nArgs:\n- `holder`: Internal client managing HTTP connections and async operations\n\nExample:\n```python\nrest_client = service_client.create_rest_client()\ntraining_run = rest_client.get_training_run(\"run-id\").result()\nprint(f\"Training Run: {training_run.training_run_id}, LoRA: {training_run.is_lora}\")\ncheckpoints = rest_client.list_checkpoints(\"run-id\").result()\nprint(f\"Found {len(checkpoints.checkpoints)} checkpoints\")\nfor checkpoint in checkpoints.checkpoints:\n    print(f\"  {checkpoint.checkpoint_type}: {checkpoint.checkpoint_id}\")\n```\n\n#### `get_training_run`\n\n```python\ndef get_training_run(\n    training_run_id: types.ModelID,\n    access_scope: Literal[\"owned\", \"accessible\"] = \"owned\"\n) -> ConcurrentFuture[types.TrainingRun]\n```\n\nGet training run info.\n\nArgs:\n- `training_run_id`: The training run ID to get information for\n\nReturns:\n- A `Future` containing the training run information\n\nExample:\n```python\nfuture = rest_client.get_training_run(\"run-id\")\nresponse = future.result()\nprint(f\"Training Run ID: {response.training_run_id}, Base: {response.base_model}\")\n```\n\n#### `get_training_run_async`\n\n```python\nasync def get_training_run_async(\n    training_run_id: types.ModelID,\n    access_scope: Literal[\"owned\",\n                          \"accessible\"] = \"owned\") -> types.TrainingRun\n```\n\nAsync version of get_training_run.\n\n#### `get_training_run_by_tinker_path`\n\n```python\ndef get_training_run_by_tinker_path(\n    tinker_path: str,\n    access_scope: Literal[\"owned\", \"accessible\"] = \"owned\"\n) -> ConcurrentFuture[types.TrainingRun]\n```\n\nGet training run info.\n\nArgs:\n- `tinker_path`: The tinker path to the checkpoint\n\nReturns:\n- A `Future` containing the training run information\n\nExample:\n```python\nfuture = rest_client.get_training_run_by_tinker_path(\"tinker://run-id/weights/checkpoint-001\")\nresponse = future.result()\nprint(f\"Training Run ID: {response.training_run_id}, Base: {response.base_model}\")\n```\n\n#### `get_training_run_by_tinker_path_async`\n\n```python\nasync def get_training_run_by_tinker_path_async(\n    tinker_path: str,\n    access_scope: Literal[\"owned\",\n                          \"accessible\"] = \"owned\") -> types.TrainingRun\n```\n\nAsync version of get_training_run_by_tinker_path.\n\n#### `get_weights_info_by_tinker_path`\n\n```python\ndef get_weights_info_by_tinker_path(\n        tinker_path: str) -> 
APIFuture[types.WeightsInfoResponse]\n```\n\nGet checkpoint information from a tinker path.\n\nArgs:\n- `tinker_path`: The tinker path to the checkpoint\n\nReturns:\n- An `APIFuture` containing the checkpoint information. The future is awaitable.\n\nExample:\n```python\nfuture = rest_client.get_weights_info_by_tinker_path(\"tinker://run-id/weights/checkpoint-001\")\nresponse = future.result()  # or await future\nprint(f\"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}\")\n```\n\n#### `list_training_runs`\n\n```python\ndef list_training_runs(\n    limit: int = 20,\n    offset: int = 0,\n    access_scope: Literal[\"owned\", \"accessible\"] = \"owned\"\n) -> ConcurrentFuture[types.TrainingRunsResponse]\n```\n\nList training runs with pagination support.\n\nArgs:\n- `limit`: Maximum number of training runs to return (default 20)\n- `offset`: Offset for pagination (default 0)\n\nReturns:\n- A `Future` containing the `TrainingRunsResponse` with training runs and cursor info\n\nExample:\n```python\nfuture = rest_client.list_training_runs(limit=50)\nresponse = future.result()\nprint(f\"Found {len(response.training_runs)} training runs\")\nprint(f\"Total: {response.cursor.total_count}\")\n# Get next page\nnext_page = rest_client.list_training_runs(limit=50, offset=50)\n```\n\n#### `list_training_runs_async`\n\n```python\nasync def list_training_runs_async(\n    limit: int = 20,\n    offset: int = 0,\n    access_scope: Literal[\"owned\", \"accessible\"] = \"owned\"\n) -> types.TrainingRunsResponse\n```\n\nAsync version of list_training_runs.\n\n#### `list_checkpoints`\n\n```python\ndef list_checkpoints(\n    training_run_id: types.ModelID\n) -> ConcurrentFuture[types.CheckpointsListResponse]\n```\n\nList available checkpoints (both training and sampler).\n\nArgs:\n- `training_run_id`: The training run ID to list checkpoints for\n\nReturns:\n- A `Future` containing the `CheckpointsListResponse` with available checkpoints\n\nExample:\n```python\nfuture = rest_client.list_checkpoints(\"run-id\")\nresponse = future.result()\nfor checkpoint in response.checkpoints:\n    if checkpoint.checkpoint_type == \"training\":\n        print(f\"Training checkpoint: {checkpoint.checkpoint_id}\")\n    elif checkpoint.checkpoint_type == \"sampler\":\n        print(f\"Sampler checkpoint: {checkpoint.checkpoint_id}\")\n```\n\n#### `list_checkpoints_async`\n\n```python\nasync def list_checkpoints_async(\n        training_run_id: types.ModelID) -> types.CheckpointsListResponse\n```\n\nAsync version of list_checkpoints.\n\n#### `get_checkpoint_archive_url`\n\n```python\ndef get_checkpoint_archive_url(\n    training_run_id: types.ModelID, checkpoint_id: str\n) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]\n```\n\nGet signed URL to download checkpoint archive.\n\nArgs:\n- `training_run_id`: The training run ID to download weights for\n- `checkpoint_id`: The checkpoint ID to download\n\nReturns:\n- A `Future` containing the `CheckpointArchiveUrlResponse` with signed URL and expiration\n\nExample:\n```python\nfuture = rest_client.get_checkpoint_archive_url(\"run-id\", \"checkpoint-123\")\nresponse = future.result()\nprint(f\"Download URL: {response.url}\")\nprint(f\"Expires at: {response.expires_at}\")\n# Use the URL to download the archive with your preferred HTTP client\n```\n\n#### `get_checkpoint_archive_url_async`\n\n```python\nasync def get_checkpoint_archive_url_async(\n        training_run_id: types.ModelID,\n        checkpoint_id: str) -> types.CheckpointArchiveUrlResponse\n```\n\nAsync 
version of get_checkpoint_archive_url.\n\n#### `delete_checkpoint`\n\n```python\ndef delete_checkpoint(training_run_id: types.ModelID,\n                      checkpoint_id: str) -> ConcurrentFuture[None]\n```\n\nDelete a checkpoint for a training run.\n\n#### `delete_checkpoint_async`\n\n```python\nasync def delete_checkpoint_async(training_run_id: types.ModelID,\n                                  checkpoint_id: str) -> None\n```\n\nAsync version of delete_checkpoint.\n\n#### `delete_checkpoint_from_tinker_path`\n\n```python\ndef delete_checkpoint_from_tinker_path(\n        tinker_path: str) -> ConcurrentFuture[None]\n```\n\nDelete a checkpoint referenced by a tinker path.\n\n#### `delete_checkpoint_from_tinker_path_async`\n\n```python\nasync def delete_checkpoint_from_tinker_path_async(tinker_path: str) -> None\n```\n\nAsync version of delete_checkpoint_from_tinker_path.\n\n#### `get_checkpoint_archive_url_from_tinker_path`\n\n```python\ndef get_checkpoint_archive_url_from_tinker_path(\n        tinker_path: str\n) -> ConcurrentFuture[types.CheckpointArchiveUrlResponse]\n```\n\nGet signed URL to download checkpoint archive.\n\nArgs:\n- `tinker_path`: The tinker path to the checkpoint\n\nReturns:\n- A `Future` containing the `CheckpointArchiveUrlResponse` with signed URL and expiration\n\n#### `get_checkpoint_archive_url_from_tinker_path_async`\n\n```python\nasync def get_checkpoint_archive_url_from_tinker_path_async(\n        tinker_path: str) -> types.CheckpointArchiveUrlResponse\n```\n\nAsync version of get_checkpoint_archive_url_from_tinker_path.\n\n#### `publish_checkpoint_from_tinker_path`\n\n```python\ndef publish_checkpoint_from_tinker_path(\n        tinker_path: str) -> ConcurrentFuture[None]\n```\n\nPublish a checkpoint referenced by a tinker path to make it publicly accessible.\n\nOnly the exact owner of the training run can publish checkpoints.\nPublished checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path method.\n\nArgs:\n- `tinker_path`: The tinker path to the checkpoint (e.g., \"tinker://run-id/weights/0001\")\n\nReturns:\n- A `Future` that completes when the checkpoint is published\n\nRaises:\n    HTTPException: 400 if checkpoint identifier is invalid\n    HTTPException: 404 if checkpoint not found or user doesn't own the training run\n    HTTPException: 409 if checkpoint is already public\n    HTTPException: 500 if there's an error publishing the checkpoint\n\nExample:\n```python\nfuture = rest_client.publish_checkpoint_from_tinker_path(\"tinker://run-id/weights/0001\")\nfuture.result()  # Wait for completion\nprint(\"Checkpoint published successfully\")\n```\n\n#### `publish_checkpoint_from_tinker_path_async`\n\n```python\nasync def publish_checkpoint_from_tinker_path_async(tinker_path: str) -> None\n```\n\nAsync version of publish_checkpoint_from_tinker_path.\n\n#### `unpublish_checkpoint_from_tinker_path`\n\n```python\ndef unpublish_checkpoint_from_tinker_path(\n        tinker_path: str) -> ConcurrentFuture[None]\n```\n\nUnpublish a checkpoint referenced by a tinker path to make it private again.\n\nOnly the exact owner of the training run can unpublish checkpoints.\nThis reverses the effect of publishing a checkpoint.\n\nArgs:\n- `tinker_path`: The tinker path to the checkpoint (e.g., \"tinker://run-id/weights/0001\")\n\nReturns:\n- A `Future` that completes when the checkpoint is unpublished\n\nRaises:\n    HTTPException: 400 if checkpoint identifier is invalid\n    HTTPException: 404 if checkpoint not found or user doesn't own the training 
run\n    HTTPException: 409 if checkpoint is already private\n    HTTPException: 500 if there's an error unpublishing the checkpoint\n\nExample:\n```python\nfuture = rest_client.unpublish_checkpoint_from_tinker_path(\"tinker://run-id/weights/0001\")\nfuture.result()  # Wait for completion\nprint(\"Checkpoint unpublished successfully\")\n```\n\n#### `unpublish_checkpoint_from_tinker_path_async`\n\n```python\nasync def unpublish_checkpoint_from_tinker_path_async(\n        tinker_path: str) -> None\n```\n\nAsync version of unpublish_checkpoint_from_tinker_path.\n\n#### `set_checkpoint_ttl_from_tinker_path`\n\n```python\ndef set_checkpoint_ttl_from_tinker_path(\n        tinker_path: str, ttl_seconds: int | None) -> ConcurrentFuture[None]\n```\n\nSet or remove the TTL on a checkpoint referenced by a tinker path.\n\nIf ttl_seconds is provided, the checkpoint will expire after that many seconds from now.\nIf ttl_seconds is None, any existing expiration will be removed.\n\nArgs:\n- `tinker_path`: The tinker path to the checkpoint (e.g., \"tinker://run-id/weights/0001\")\n- `ttl_seconds`: Number of seconds until expiration, or None to remove TTL\n\nReturns:\n- A `Future` that completes when the TTL is set\n\nRaises:\n    HTTPException: 400 if checkpoint identifier is invalid or ttl_seconds <= 0\n    HTTPException: 404 if checkpoint not found or user doesn't own the training run\n    HTTPException: 500 if there's an error setting the TTL\n\nExample:\n```python\nfuture = rest_client.set_checkpoint_ttl_from_tinker_path(\"tinker://run-id/weights/0001\", 86400)\nfuture.result()  # Wait for completion\nprint(\"Checkpoint TTL set successfully\")\n```\n\n#### `set_checkpoint_ttl_from_tinker_path_async`\n\n```python\nasync def set_checkpoint_ttl_from_tinker_path_async(\n        tinker_path: str, ttl_seconds: int | None) -> None\n```\n\nAsync version of set_checkpoint_ttl_from_tinker_path.\n\n#### `list_user_checkpoints`\n\n```python\ndef list_user_checkpoints(\n        limit: int = 100,\n        offset: int = 0) -> ConcurrentFuture[types.CheckpointsListResponse]\n```\n\nList all checkpoints for the current user across all their training runs.\n\nThis method retrieves checkpoints from all training runs owned by the authenticated user,\nsorted by time (newest first). 
It supports pagination for efficiently handling large\nnumbers of checkpoints.\n\nArgs:\n- `limit`: Maximum number of checkpoints to return (default 100)\n- `offset`: Offset for pagination (default 0)\n\nReturns:\n- A `Future` containing the `CheckpointsListResponse` with checkpoints and cursor info\n\nExample:\n```python\nfuture = rest_client.list_user_checkpoints(limit=50)\nresponse = future.result()\nprint(f\"Found {len(response.checkpoints)} checkpoints\")\nprint(f\"Total: {response.cursor.total_count if response.cursor else 'Unknown'}\")\nfor checkpoint in response.checkpoints:\n    print(f\"  {checkpoint.training_run_id}/{checkpoint.checkpoint_id}\")\n# Get next page if there are more checkpoints\nif response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count:\n    next_page = rest_client.list_user_checkpoints(limit=50, offset=50)\n```\n\n#### `list_user_checkpoints_async`\n\n```python\nasync def list_user_checkpoints_async(limit: int = 100,\n                                      offset: int = 0\n                                      ) -> types.CheckpointsListResponse\n```\n\nAsync version of list_user_checkpoints.\n\n#### `get_session`\n\n```python\ndef get_session(\n    session_id: str,\n    access_scope: Literal[\"owned\", \"accessible\"] = \"owned\"\n) -> ConcurrentFuture[types.GetSessionResponse]\n```\n\nGet session information including all training runs and samplers.\n\nArgs:\n- `session_id`: The session ID to get information for\n\nReturns:\n- A `Future` containing the `GetSessionResponse` with training_run_ids and sampler_ids\n\nExample:\n```python\nfuture = rest_client.get_session(\"session-id\")\nresponse = future.result()\nprint(f\"Training runs: {len(response.training_run_ids)}\")\nprint(f\"Samplers: {len(response.sampler_ids)}\")\n```\n\n#### `get_session_async`\n\n```python\nasync def get_session_async(\n    session_id: str,\n    access_scope: Literal[\"owned\", \"accessible\"] = \"owned\"\n) -> types.GetSessionResponse\n```\n\nAsync version of get_session.\n\n#### `list_sessions`\n\n```python\ndef list_sessions(\n    limit: int = 20,\n    offset: int = 0,\n    access_scope: Literal[\"owned\", \"accessible\"] = \"owned\"\n) -> ConcurrentFuture[types.ListSessionsResponse]\n```\n\nList sessions with pagination support.\n\nArgs:\n- `limit`: Maximum number of sessions to return (default 20)\n- `offset`: Offset for pagination (default 0)\n\nReturns:\n- A `Future` containing the `ListSessionsResponse` with list of session IDs\n\nExample:\n```python\nfuture = rest_client.list_sessions(limit=50)\nresponse = future.result()\nprint(f\"Found {len(response.sessions)} sessions\")\n# Get next page\nnext_page = rest_client.list_sessions(limit=50, offset=50)\n```\n\n#### `list_sessions_async`\n\n```python\nasync def list_sessions_async(\n    limit: int = 20,\n    offset: int = 0,\n    access_scope: Literal[\"owned\", \"accessible\"] = \"owned\"\n) -> types.ListSessionsResponse\n```\n\nAsync version of list_sessions.\n\n#### `get_sampler`\n\n```python\ndef get_sampler(sampler_id: str) -> APIFuture[types.GetSamplerResponse]\n```\n\nGet sampler information.\n\nArgs:\n- `sampler_id`: The sampler ID (sampling_session_id) to get information for\n\nReturns:\n- An `APIFuture` containing the `GetSamplerResponse` with sampler details\n\nExample:\n```python\n# Sync usage\nfuture = rest_client.get_sampler(\"session-id:sample:0\")\nresponse = future.result()\nprint(f\"Base model: {response.base_model}\")\nprint(f\"Model path: {response.model_path}\")\n\n# Async 
usage\nresponse = await rest_client.get_sampler(\"session-id:sample:0\")\nprint(f\"Base model: {response.base_model}\")\n```\n\n#### `get_sampler_async`\n\n```python\nasync def get_sampler_async(sampler_id: str) -> types.GetSamplerResponse\n```\n\nAsync version of get_sampler.\n"
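The archive-URL methods above return a signed URL rather than the archive bytes, so a separate download step is needed. A hedged sketch using the standard-library `urllib` follows; the tinker path is a placeholder.

```python
import urllib.request

from tinker import ServiceClient

rest_client = ServiceClient().create_rest_client()

# Placeholder checkpoint path — take a real one from list_user_checkpoints().
tinker_path = "tinker://run-id/weights/checkpoint-001"

# The future resolves to a CheckpointArchiveUrlResponse; the archive itself is
# fetched separately with any HTTP client (urllib here for simplicity).
response = rest_client.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()
urllib.request.urlretrieve(response.url, "model-checkpoint.tar.gz")
```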
  },
  {
    "path": "docs/api-reference/samplingclient.md",
    "content": "SamplingClient for Tinker API.\n\n## `SamplingClient` Objects\n\n```python\nclass SamplingClient(TelemetryProvider, QueueStateObserver)\n```\n\nClient for text generation and inference from trained or base models.\n\nThe SamplingClient lets you generate text tokens from either a base model or from weights\nyou've saved using a TrainingClient. You typically get one by calling\n`service_client.create_sampling_client()` or `training_client.save_weights_and_get_sampling_client()`.\n\nKey methods:\n- sample() - generate text completions with customizable parameters\n- compute_logprobs() - get log probabilities for prompt tokens\n\nCreate method parameters:\n- `model_path`: Path to saved model weights (starts with 'tinker://')\n- `base_model`: Name of base model to use for inference (e.g., 'Qwen/Qwen3-8B')\n- `retry_config`: Configuration for retrying failed requests\n\nExample:\n```python\nsampling_client = service_client.create_sampling_client(base_model=\"Qwen/Qwen3-8B\")\nprompt = types.ModelInput.from_ints(tokenizer.encode(\"The weather today is\"))\nparams = types.SamplingParams(max_tokens=20, temperature=0.7)\nfuture = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)\nresult = future.result()\n```\n\nMulti-processing support:\nThis class is picklable, so it can be passed to a separate process/worker to sample. It is also\nsafe to pass the same instance of SamplingClient to multiple processes/workers.\n\nIf you are using Tinker SDK with more than one process you should always create SamplingClient from\nthe main process and then pass it to the other processes/workers.\nServiceClient and TrainingClient should always be managed from the main process.\n\nSubprocess isolation:\nSet ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated\nsubprocess, preventing GIL contention from CPU-heavy user code (grading, environment\ninteractions) from stalling networking IO and heartbeats. 
This is transparent — the same\nAPI works with or without it.\n\n#### `sample`\n\n```python\ndef sample(\n        prompt: types.ModelInput,\n        num_samples: int,\n        sampling_params: types.SamplingParams,\n        include_prompt_logprobs: bool = False,\n        topk_prompt_logprobs: int = 0\n) -> ConcurrentFuture[types.SampleResponse]\n```\n\nGenerate text completions from the model.\n\nArgs:\n- `prompt`: The input tokens as ModelInput\n- `num_samples`: Number of independent samples to generate\n- `sampling_params`: Parameters controlling generation (temperature, max_tokens, etc.)\n- `include_prompt_logprobs`: Whether to include log probabilities for prompt tokens\n- `topk_prompt_logprobs`: Number of top token log probabilities to return per position\n\nReturns:\n- A `Future` containing the `SampleResponse` with generated text\n\nExample:\n```python\nprompt = types.ModelInput.from_ints(tokenizer.encode(\"The weather today is\"))\nparams = types.SamplingParams(max_tokens=20, temperature=0.7)\nfuture = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)\nresult = future.result()\nfor sample in result.samples:\n    print(tokenizer.decode(sample.tokens))\n```\n\n#### `sample_async`\n\n```python\nasync def sample_async(prompt: types.ModelInput,\n                       num_samples: int,\n                       sampling_params: types.SamplingParams,\n                       include_prompt_logprobs: bool = False,\n                       topk_prompt_logprobs: int = 0) -> types.SampleResponse\n```\n\nAsync version of sample.\n\n#### `compute_logprobs`\n\n```python\ndef compute_logprobs(\n        prompt: types.ModelInput) -> ConcurrentFuture[list[float | None]]\n```\n\nCompute log probabilities for prompt tokens.\n\nArgs:\n- `prompt`: The input tokens as ModelInput\n\nReturns:\n- A `Future` containing a list of log probabilities for each token in the prompt.\n    None values indicate tokens where log probabilities couldn't be computed.\n\nExample:\n```python\nprompt = types.ModelInput.from_ints(tokenizer.encode(\"Hello world\"))\nfuture = sampling_client.compute_logprobs(prompt)\nlogprobs = future.result()\nfor i, logprob in enumerate(logprobs):\n    if logprob is not None:\n        print(f\"Token {i}: logprob = {logprob:.4f}\")\n```\n\n#### `compute_logprobs_async`\n\n```python\nasync def compute_logprobs_async(\n        prompt: types.ModelInput) -> list[float | None]\n```\n\nAsync version of compute_logprobs.\n\n#### `get_tokenizer`\n\n```python\ndef get_tokenizer() -> PreTrainedTokenizer\n```\n\nGet the tokenizer for the current model.\n\nReturns:\n- `PreTrainedTokenizer` compatible with the model\n\n#### `get_base_model`\n\n```python\ndef get_base_model() -> str\n```\n\nGet the base model name for the current sampling session.\n\n#### `get_base_model_async`\n\n```python\nasync def get_base_model_async() -> str\n```\n\nAsync version of get_base_model.\n\n#### `__reduce__`\n\n```python\ndef __reduce__() -> tuple[Any, tuple[_SamplingClientPickleState]]\n```\n\nEnable pickling of SamplingClient for subprocess use.\n\nSerializes into a ``_SamplingClientPickleState`` dataclass. The\n``_sampling_client_sidecar_handle`` handle is deliberately omitted — only a\nbool flag is stored. The unpickled copy creates its own handle via\nthe per-process sidecar singleton. Do not add ``__getstate__``\nwithout preserving this behavior.\n"
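A rough sketch of the multi-processing pattern described above: the ServiceClient and SamplingClient are created in the main process, and the (picklable) SamplingClient is passed to worker processes that only sample. The base model name, the prompts, and the `from tinker import types` path are assumptions.

```python
import multiprocessing as mp

import tinker
from tinker import types  # import path for `types` assumed


def generate(args):
    # Runs in a worker process; the SamplingClient was created in the main
    # process and pickled across, as the notes above allow.
    sampling_client, prompt_text = args
    tokenizer = sampling_client.get_tokenizer()
    prompt = types.ModelInput.from_ints(tokenizer.encode(prompt_text))
    params = types.SamplingParams(max_tokens=20, temperature=0.7)
    result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
    return tokenizer.decode(result.samples[0].tokens)


if __name__ == "__main__":
    # Create clients in the main process only, then hand the sampling client to workers.
    service_client = tinker.ServiceClient()
    sampling_client = service_client.create_sampling_client(base_model="Qwen/Qwen3-8B")
    prompts = ["The weather today is", "Once upon a time"]
    with mp.Pool(processes=2) as pool:
        print(pool.map(generate, [(sampling_client, p) for p in prompts]))
```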
  },
  {
    "path": "docs/api-reference/serviceclient.md",
    "content": "ServiceClient for Tinker API.\n\n## `ServiceClient` Objects\n\n```python\nclass ServiceClient(TelemetryProvider)\n```\n\nThe ServiceClient is the main entry point for the Tinker API. It provides methods to:\n- Query server capabilities and health status\n- Generate TrainingClient instances for model training workflows\n- Generate SamplingClient instances for text generation and inference\n- Generate RestClient instances for REST API operations like listing weights\n\nArgs:\n    user_metadata: Optional metadata attached to the created session.\n    project_id: Optional project ID to attach to the created session.\n    **kwargs: advanced options passed to the underlying HTTP client,\n             including API keys, headers, and connection settings.\n\nExample:\n```python\n# Near instant\nclient = ServiceClient()\n\n# Takes a moment as we initialize the model and assign resources\ntraining_client = client.create_lora_training_client(base_model=\"Qwen/Qwen3-8B\")\n\n# Near-instant\nsampling_client = client.create_sampling_client(base_model=\"Qwen/Qwen3-8B\")\n\n# Near-instant\nrest_client = client.create_rest_client()\n```\n\n#### `get_server_capabilities`\n\n```python\ndef get_server_capabilities() -> types.GetServerCapabilitiesResponse\n```\n\nQuery the server's supported features and capabilities.\n\nReturns:\n- `GetServerCapabilitiesResponse` with available models, features, and limits\n\nExample:\n```python\ncapabilities = service_client.get_server_capabilities()\nprint(f\"Supported models: {capabilities.supported_models}\")\nprint(f\"Max batch size: {capabilities.max_batch_size}\")\n```\n\n#### `get_server_capabilities_async`\n\n```python\nasync def get_server_capabilities_async(\n) -> types.GetServerCapabilitiesResponse\n```\n\nAsync version of get_server_capabilities.\n\n#### `create_lora_training_client`\n\n```python\ndef create_lora_training_client(\n        base_model: str,\n        rank: int = 32,\n        seed: int | None = None,\n        train_mlp: bool = True,\n        train_attn: bool = True,\n        train_unembed: bool = True,\n        user_metadata: dict[str, str] | None = None) -> TrainingClient\n```\n\nCreate a TrainingClient for LoRA fine-tuning.\n\nArgs:\n- `base_model`: Name of the base model to fine-tune (e.g., \"Qwen/Qwen3-8B\")\n- `rank`: LoRA rank controlling the size of adaptation matrices (default 32)\n- `seed`: Random seed for initialization. 
None means random seed.\n- `train_mlp`: Whether to train MLP layers (default True)\n- `train_attn`: Whether to train attention layers (default True)\n- `train_unembed`: Whether to train unembedding layers (default True)\n- `user_metadata`: Optional metadata to attach to the training run\n\nReturns:\n- `TrainingClient` configured for LoRA training\n\nExample:\n```python\ntraining_client = service_client.create_lora_training_client(\n    base_model=\"Qwen/Qwen3-8B\",\n    rank=16,\n    train_mlp=True,\n    train_attn=True\n)\n# Now use training_client.forward_backward() to train\n```\n\n#### `create_lora_training_client_async`\n\n```python\nasync def create_lora_training_client_async(\n        base_model: str,\n        rank: int = 32,\n        seed: int | None = None,\n        train_mlp: bool = True,\n        train_attn: bool = True,\n        train_unembed: bool = True,\n        user_metadata: dict[str, str] | None = None) -> TrainingClient\n```\n\nAsync version of create_lora_training_client.\n\n#### `create_training_client_from_state`\n\n```python\ndef create_training_client_from_state(\n        path: str,\n        user_metadata: dict[str, str] | None = None) -> TrainingClient\n```\n\nCreate a TrainingClient from saved model weights.\n\nThis loads only the model weights, not optimizer state. To also restore\noptimizer state (e.g., Adam momentum), use create_training_client_from_state_with_optimizer.\n\nArgs:\n- `path`: Tinker path to saved weights (e.g., \"tinker://run-id/weights/checkpoint-001\")\n- `user_metadata`: Optional metadata to attach to the new training run\n\nReturns:\n- `TrainingClient` loaded with the specified weights\n\nExample:\n```python\n# Resume training from a checkpoint (weights only, optimizer resets)\ntraining_client = service_client.create_training_client_from_state(\n    \"tinker://run-id/weights/checkpoint-001\"\n)\n# Continue training from the loaded state\n```\n\n#### `create_training_client_from_state_async`\n\n```python\nasync def create_training_client_from_state_async(\n        path: str,\n        user_metadata: dict[str, str] | None = None) -> TrainingClient\n```\n\nAsync version of create_training_client_from_state.\n\n#### `create_training_client_from_state_with_optimizer`\n\n```python\ndef create_training_client_from_state_with_optimizer(\n        path: str,\n        user_metadata: dict[str, str] | None = None) -> TrainingClient\n```\n\nCreate a TrainingClient from saved model weights and optimizer state.\n\nThis is similar to create_training_client_from_state but also restores\noptimizer state (e.g., Adam momentum), which is useful for resuming\ntraining exactly where it left off.\n\nArgs:\n- `path`: Tinker path to saved weights (e.g., \"tinker://run-id/weights/checkpoint-001\")\n- `user_metadata`: Optional metadata to attach to the new training run\n\nReturns:\n- `TrainingClient` loaded with the specified weights and optimizer state\n\nExample:\n```python\n# Resume training from a checkpoint with optimizer state\ntraining_client = service_client.create_training_client_from_state_with_optimizer(\n    \"tinker://run-id/weights/checkpoint-001\"\n)\n# Continue training with restored optimizer momentum\n```\n\n#### `create_training_client_from_state_with_optimizer_async`\n\n```python\nasync def create_training_client_from_state_with_optimizer_async(\n        path: str,\n        user_metadata: dict[str, str] | None = None) -> TrainingClient\n```\n\nAsync version of create_training_client_from_state_with_optimizer.\n\n#### 
`create_sampling_client`\n\n```python\ndef create_sampling_client(\n        model_path: str | None = None,\n        base_model: str | None = None,\n        retry_config: RetryConfig | None = None) -> SamplingClient\n```\n\nCreate a SamplingClient for text generation.\n\nArgs:\n- `model_path`: Path to saved model weights (e.g., \"tinker://run-id/weights/checkpoint-001\")\n- `base_model`: Name of base model to use (e.g., \"Qwen/Qwen3-8B\")\n- `retry_config`: Optional configuration for retrying failed requests\n\nReturns:\n- `SamplingClient` configured for text generation\n\nRaises:\n    ValueError: If neither model_path nor base_model is provided\n\nExample:\n```python\n# Use a base model\nsampling_client = service_client.create_sampling_client(\n    base_model=\"Qwen/Qwen3-8B\"\n)\n\n# Or use saved weights\nsampling_client = service_client.create_sampling_client(\n    model_path=\"tinker://run-id/weights/checkpoint-001\"\n)\n```\n\n#### `create_sampling_client_async`\n\n```python\nasync def create_sampling_client_async(\n        model_path: str | None = None,\n        base_model: str | None = None,\n        retry_config: RetryConfig | None = None) -> SamplingClient\n```\n\nAsync version of create_sampling_client.\n\n#### `create_rest_client`\n\n```python\ndef create_rest_client() -> RestClient\n```\n\nCreate a RestClient for REST API operations.\n\nThe RestClient provides access to various REST endpoints for querying\nmodel information, checkpoints, sessions, and managing checkpoint visibility.\n\nReturns:\n- `RestClient` for accessing REST API endpoints\n\nExample:\n```python\nrest_client = service_client.create_rest_client()\n\n# List checkpoints for a training run\ncheckpoints = rest_client.list_checkpoints(\"run-id\").result()\n\n# Get training run info\ntraining_run = rest_client.get_training_run(\"run-id\").result()\n\n# Publish a checkpoint\nrest_client.publish_checkpoint_from_tinker_path(\n    \"tinker://run-id/weights/checkpoint-001\"\n).result()\n```\n"
  },
  {
    "path": "docs/api-reference/trainingclient.md",
    "content": "TrainingClient for Tinker API.\n\n## `TrainingClient` Objects\n\n```python\nclass TrainingClient(TelemetryProvider)\n```\n\nClient for training ML models with forward/backward passes and optimization.\n\nThe TrainingClient corresponds to a fine-tuned model that you can train and sample from.\nYou typically get one by calling `service_client.create_lora_training_client()`.\nKey methods:\n- forward_backward() - compute gradients for training\n- optim_step() - update model parameters with Adam optimizer\n- save_weights_and_get_sampling_client() - export trained model for inference\n\nArgs:\n- `holder`: Internal client managing HTTP connections and async operations\n- `model_id`: Unique identifier for the model to train. Required for training operations.\n\nExample:\n```python\ntraining_client = service_client.create_lora_training_client(base_model=\"Qwen/Qwen3-8B\")\nfwdbwd_future = training_client.forward_backward(training_data, \"cross_entropy\")\noptim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))\nfwdbwd_result = fwdbwd_future.result()  # Wait for gradients\noptim_result = optim_future.result()    # Wait for parameter update\nsampling_client = training_client.save_weights_and_get_sampling_client(\"my-model\")\n```\n\n#### `forward`\n\n```python\ndef forward(\n    data: List[types.Datum],\n    loss_fn: types.LossFnType,\n    loss_fn_config: Dict[str, float] | None = None\n) -> APIFuture[types.ForwardBackwardOutput]\n```\n\nCompute forward pass without gradients.\n\nArgs:\n- `data`: List of training data samples\n- `loss_fn`: Loss function type (e.g., \"cross_entropy\")\n- `loss_fn_config`: Optional configuration for the loss function\n\nReturns:\n- `APIFuture` containing the forward pass outputs and loss\n\nExample:\n```python\ndata = [types.Datum(\n    model_input=types.ModelInput.from_ints(tokenizer.encode(\"Hello\")),\n    loss_fn_inputs={\"target_tokens\": types.ModelInput.from_ints(tokenizer.encode(\"world\"))}\n)]\nfuture = training_client.forward(data, \"cross_entropy\")\nresult = await future\nprint(f\"Loss: {result.loss}\")\n```\n\n#### `forward_async`\n\n```python\nasync def forward_async(\n    data: List[types.Datum],\n    loss_fn: types.LossFnType,\n    loss_fn_config: Dict[str, float] | None = None\n) -> APIFuture[types.ForwardBackwardOutput]\n```\n\nAsync version of forward.\n\n#### `forward_backward`\n\n```python\ndef forward_backward(\n    data: List[types.Datum],\n    loss_fn: types.LossFnType,\n    loss_fn_config: Dict[str, float] | None = None\n) -> APIFuture[types.ForwardBackwardOutput]\n```\n\nCompute forward pass and backward pass to calculate gradients.\n\nArgs:\n- `data`: List of training data samples\n- `loss_fn`: Loss function type (e.g., \"cross_entropy\")\n- `loss_fn_config`: Optional configuration for the loss function\n\nReturns:\n- `APIFuture` containing the forward/backward outputs, loss, and gradients\n\nExample:\n```python\ndata = [types.Datum(\n    model_input=types.ModelInput.from_ints(tokenizer.encode(\"Hello\")),\n    loss_fn_inputs={\"target_tokens\": types.ModelInput.from_ints(tokenizer.encode(\"world\"))}\n)]\n\n# Compute gradients\nfwdbwd_future = training_client.forward_backward(data, \"cross_entropy\")\n\n# Update parameters\noptim_future = training_client.optim_step(\n    types.AdamParams(learning_rate=1e-4)\n)\n\nfwdbwd_result = await fwdbwd_future\nprint(f\"Loss: {fwdbwd_result.loss}\")\n```\n\n#### `forward_backward_async`\n\n```python\nasync def forward_backward_async(\n    data: 
List[types.Datum],\n    loss_fn: types.LossFnType,\n    loss_fn_config: Dict[str, float] | None = None\n) -> APIFuture[types.ForwardBackwardOutput]\n```\n\nAsync version of forward_backward.\n\n#### `forward_backward_custom`\n\n```python\ndef forward_backward_custom(\n    data: List[types.Datum],\n    loss_fn: CustomLossFnV1,\n    *,\n    loss_type_input: Literal[\"logprobs\"] = \"logprobs\"\n) -> APIFuture[types.ForwardBackwardOutput]\n```\n\nCompute forward/backward with a custom loss function.\n\nAllows you to define custom loss functions that operate on log probabilities.\nThe custom function receives logprobs and computes loss and gradients.\n\nArgs:\n- `data`: List of training data samples\n- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)\n- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `\"logprobs\"`.\n\nReturns:\n- `APIFuture` containing the forward/backward outputs with custom loss\n\nExample:\n```python\ndef custom_loss(data, logprobs_list):\n    # Custom loss computation\n    loss = torch.mean(torch.stack([torch.mean(lp) for lp in logprobs_list]))\n    metrics = {\"custom_metric\": loss.item()}\n    return loss, metrics\n\nfuture = training_client.forward_backward_custom(data, custom_loss)\nresult = future.result()\nprint(f\"Custom loss: {result.loss}\")\nprint(f\"Metrics: {result.metrics}\")\n```\n\n#### `forward_backward_custom_async`\n\n```python\nasync def forward_backward_custom_async(\n    data: List[types.Datum],\n    loss_fn: CustomLossFnV1,\n    *,\n    loss_type_input: Literal[\"logprobs\"] = \"logprobs\"\n) -> APIFuture[types.ForwardBackwardOutput]\n```\n\nAsync version of forward_backward_custom.\n\n#### `optim_step`\n\n```python\ndef optim_step(\n        adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]\n```\n\nUpdate model parameters using Adam optimizer.\n\nThe Adam optimizer used by tinker is identical\nto [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).\nNote that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).\n\n\nArgs:\n- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)\n\nReturns:\n- `APIFuture` containing optimizer step response\n\nExample:\n```python\n# First compute gradients\nfwdbwd_future = training_client.forward_backward(data, \"cross_entropy\")\n\n# Then update parameters\noptim_future = training_client.optim_step(\n    types.AdamParams(\n        learning_rate=1e-4,\n        weight_decay=0.01\n    )\n)\n\n# Wait for both to complete\nfwdbwd_result = await fwdbwd_future\noptim_result = await optim_future\n```\n\n#### `optim_step_async`\n\n```python\nasync def optim_step_async(\n        adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]\n```\n\nAsync version of optim_step.\n\n#### `save_state`\n\n```python\ndef save_state(\n        name: str,\n        ttl_seconds: int | None = None\n) -> APIFuture[types.SaveWeightsResponse]\n```\n\nSave model weights to persistent storage.\n\nArgs:\n- `name`: Name for the saved checkpoint\n- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)\n\nReturns:\n- `APIFuture` containing the save response with checkpoint path\n\nExample:\n```python\n# Save after training\nsave_future = training_client.save_state(\"checkpoint-001\")\nresult = await save_future\nprint(f\"Saved to: {result.path}\")\n```\n\n#### `save_state_async`\n\n```python\nasync def save_state_async(\n        
name: str,\n        ttl_seconds: int | None = None\n) -> APIFuture[types.SaveWeightsResponse]\n```\n\nAsync version of save_state.\n\n#### `load_state`\n\n```python\ndef load_state(path: str) -> APIFuture[types.LoadWeightsResponse]\n```\n\nLoad model weights from a saved checkpoint.\n\nThis loads only the model weights, not optimizer state (e.g., Adam momentum).\nTo also restore optimizer state, use load_state_with_optimizer.\n\nArgs:\n- `path`: Tinker path to saved weights (e.g., \"tinker://run-id/weights/checkpoint-001\")\n\nReturns:\n- `APIFuture` containing the load response\n\nExample:\n```python\n# Load checkpoint to continue training (weights only, optimizer resets)\nload_future = training_client.load_state(\"tinker://run-id/weights/checkpoint-001\")\nawait load_future\n# Continue training from loaded state\n```\n\n#### `load_state_async`\n\n```python\nasync def load_state_async(path: str) -> APIFuture[types.LoadWeightsResponse]\n```\n\nAsync version of load_state.\n\n#### `load_state_with_optimizer`\n\n```python\ndef load_state_with_optimizer(\n        path: str) -> APIFuture[types.LoadWeightsResponse]\n```\n\nLoad model weights and optimizer state from a checkpoint.\n\nArgs:\n- `path`: Tinker path to saved weights (e.g., \"tinker://run-id/weights/checkpoint-001\")\n\nReturns:\n- `APIFuture` containing the load response\n\nExample:\n```python\n# Resume training with optimizer state\nload_future = training_client.load_state_with_optimizer(\n    \"tinker://run-id/weights/checkpoint-001\"\n)\nawait load_future\n# Continue training with restored optimizer momentum\n```\n\n#### `load_state_with_optimizer_async`\n\n```python\nasync def load_state_with_optimizer_async(\n        path: str) -> APIFuture[types.LoadWeightsResponse]\n```\n\nAsync version of load_state_with_optimizer.\n\n#### `save_weights_for_sampler`\n\n```python\ndef save_weights_for_sampler(\n    name: str,\n    ttl_seconds: int | None = None\n) -> APIFuture[types.SaveWeightsForSamplerResponse]\n```\n\nSave model weights for use with a SamplingClient.\n\nArgs:\n- `name`: Name for the saved sampler weights\n- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)\n\nReturns:\n- `APIFuture` containing the save response with sampler path\n\nExample:\n```python\n# Save weights for inference\nsave_future = training_client.save_weights_for_sampler(\"sampler-001\")\nresult = await save_future\nprint(f\"Sampler weights saved to: {result.path}\")\n\n# Use the path to create a sampling client\nsampling_client = service_client.create_sampling_client(\n    model_path=result.path\n)\n```\n\n#### `save_weights_for_sampler_async`\n\n```python\nasync def save_weights_for_sampler_async(\n    name: str,\n    ttl_seconds: int | None = None\n) -> APIFuture[types.SaveWeightsForSamplerResponse]\n```\n\nAsync version of save_weights_for_sampler.\n\n#### `get_info`\n\n```python\ndef get_info() -> types.GetInfoResponse\n```\n\nGet information about the current model.\n\nReturns:\n- `GetInfoResponse` with model configuration and metadata\n\nExample:\n```python\ninfo = training_client.get_info()\nprint(f\"Model ID: {info.model_data.model_id}\")\nprint(f\"Base model: {info.model_data.model_name}\")\nprint(f\"LoRA rank: {info.model_data.lora_rank}\")\n```\n\n#### `get_info_async`\n\n```python\nasync def get_info_async() -> types.GetInfoResponse\n```\n\nAsync version of get_info.\n\n#### `get_tokenizer`\n\n```python\ndef get_tokenizer() -> PreTrainedTokenizer\n```\n\nGet the tokenizer for the current model.\n\nReturns:\n- 
`PreTrainedTokenizer` compatible with the model\n\nExample:\n```python\ntokenizer = training_client.get_tokenizer()\ntokens = tokenizer.encode(\"Hello world\")\ntext = tokenizer.decode(tokens)\n```\n\n#### `create_sampling_client`\n\n```python\ndef create_sampling_client(\n        model_path: str,\n        retry_config: RetryConfig | None = None) -> SamplingClient\n```\n\nCreate a SamplingClient from saved weights.\n\nArgs:\n- `model_path`: Tinker path to saved weights\n- `retry_config`: Optional configuration for retrying failed requests\n\nReturns:\n- `SamplingClient` configured with the specified weights\n\nExample:\n```python\nsampling_client = training_client.create_sampling_client(\n    \"tinker://run-id/weights/checkpoint-001\"\n)\n# Use sampling_client for inference\n```\n\n#### `create_sampling_client_async`\n\n```python\nasync def create_sampling_client_async(\n        model_path: str,\n        retry_config: RetryConfig | None = None) -> SamplingClient\n```\n\nAsync version of create_sampling_client.\n\n#### `save_weights_and_get_sampling_client`\n\n```python\ndef save_weights_and_get_sampling_client(\n        name: str | None = None,\n        retry_config: RetryConfig | None = None) -> SamplingClient\n```\n\nSave current weights and create a SamplingClient for inference.\n\nArgs:\n- `name`: Optional name for the saved weights (currently ignored for ephemeral saves)\n- `retry_config`: Optional configuration for retrying failed requests\n\nReturns:\n- `SamplingClient` configured with the current model weights\n\nExample:\n```python\n# After training, create a sampling client directly\nsampling_client = training_client.save_weights_and_get_sampling_client()\n\n# Now use it for inference\nprompt = types.ModelInput.from_ints(tokenizer.encode(\"Hello\"))\nparams = types.SamplingParams(max_tokens=20)\nresult = sampling_client.sample(prompt, 1, params).result()\n```\n\n#### `save_weights_and_get_sampling_client_async`\n\n```python\nasync def save_weights_and_get_sampling_client_async(\n        name: str | None = None,\n        retry_config: RetryConfig | None = None) -> SamplingClient\n```\n\nAsync version of save_weights_and_get_sampling_client.\n"
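Putting the methods above together, here is a minimal multi-step sketch assembled from the examples on this page; the data, hyperparameters, and the `from tinker import types` path are placeholders rather than recommended settings.

```python
import tinker
from tinker import types  # import path for `types` assumed

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B", rank=32)
tokenizer = training_client.get_tokenizer()

# Toy single-example batch; real runs build larger Datum lists.
data = [types.Datum(
    model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")),
    loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))},
)]

for step in range(3):
    # Submit the gradient computation and the parameter update, then wait for both.
    fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
    optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
    print(f"step {step}: loss = {fwdbwd_future.result().loss}")
    optim_future.result()

# Export the trained weights and sample from them.
sampling_client = training_client.save_weights_and_get_sampling_client(name="my-model")
```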
  },
  {
    "path": "docs/api-reference/types.md",
    "content": "## `LoadWeightsResponse` Objects\n\n```python\nclass LoadWeightsResponse(BaseModel)\n```\n\n#### `path`\n\nA tinker URI for model weights at a specific step\n\n## `WeightsInfoResponse` Objects\n\n```python\nclass WeightsInfoResponse(BaseModel)\n```\n\nMinimal information for loading public checkpoints.\n\n## `LoadWeightsRequest` Objects\n\n```python\nclass LoadWeightsRequest(StrictBase)\n```\n\n#### `path`\n\nA tinker URI for model weights at a specific step\n\n#### `optimizer`\n\nWhether to load optimizer state along with model weights\n\n## `CreateModelRequest` Objects\n\n```python\nclass CreateModelRequest(StrictBase)\n```\n\n#### `base_model`\n\nThe name of the base model to fine-tune (e.g., 'Qwen/Qwen3-8B').\n\n#### `user_metadata`\n\nOptional metadata about this model/training run, set by the end-user.\n\n#### `lora_config`\n\nLoRA configuration\n\n## `UnhandledExceptionEvent` Objects\n\n```python\nclass UnhandledExceptionEvent(BaseModel)\n```\n\n#### `event`\n\nTelemetry event type\n\n#### `severity`\n\nLog severity level\n\n#### `traceback`\n\nOptional Python traceback string\n\n## `Datum` Objects\n\n```python\nclass Datum(StrictBase)\n```\n\n#### `loss_fn_inputs`\n\nDictionary mapping field names to tensor data\n\n#### `convert_tensors`\n\n```python\ndef convert_tensors(cls, data: Any) -> Any\n```\n\nConvert torch.Tensor and numpy arrays to TensorData in loss_fn_inputs during construction.\n\n## `Checkpoint` Objects\n\n```python\nclass Checkpoint(BaseModel)\n```\n\n#### `checkpoint_id`\n\nThe checkpoint ID\n\n#### `checkpoint_type`\n\nThe type of checkpoint (training or sampler)\n\n#### `time`\n\nThe time when the checkpoint was created\n\n#### `tinker_path`\n\nThe tinker path to the checkpoint\n\n#### `size_bytes`\n\nThe size of the checkpoint in bytes\n\n#### `public`\n\nWhether the checkpoint is publicly accessible\n\n#### `expires_at`\n\nWhen this checkpoint expires (None = never expires)\n\n## `ParsedCheckpointTinkerPath` Objects\n\n```python\nclass ParsedCheckpointTinkerPath(BaseModel)\n```\n\n#### `tinker_path`\n\nThe tinker path to the checkpoint\n\n#### `training_run_id`\n\nThe training run ID\n\n#### `checkpoint_type`\n\nThe type of checkpoint (training or sampler)\n\n#### `checkpoint_id`\n\nThe checkpoint ID\n\n#### `from_tinker_path`\n\n```python\ndef from_tinker_path(cls, tinker_path: str) -> \"ParsedCheckpointTinkerPath\"\n```\n\nParse a tinker path to an instance of ParsedCheckpointTinkerPath\n\n## `SamplingParams` Objects\n\n```python\nclass SamplingParams(BaseModel)\n```\n\n#### `max_tokens`\n\nMaximum number of tokens to generate\n\n#### `seed`\n\nRandom seed for reproducible generation\n\n#### `stop`\n\nStop sequences for generation\n\n#### `temperature`\n\nSampling temperature\n\n#### `top_k`\n\nTop-k sampling parameter (-1 for no limit)\n\n#### `top_p`\n\nNucleus sampling probability\n\n## `SaveWeightsForSamplerRequest` Objects\n\n```python\nclass SaveWeightsForSamplerRequest(StrictBase)\n```\n\n#### `path`\n\nA file/directory name for the weights\n\n#### `ttl_seconds`\n\nTTL in seconds for this checkpoint (None = never expires)\n\n## `ModelInput` Objects\n\n```python\nclass ModelInput(StrictBase)\n```\n\n#### `chunks`\n\nSequence of input chunks (formerly TokenSequence)\n\n#### `from_ints`\n\n```python\ndef from_ints(cls, tokens: List[int]) -> \"ModelInput\"\n```\n\nCreate a ModelInput from a list of ints (tokens).\n\n#### `to_ints`\n\n```python\ndef to_ints() -> List[int]\n```\n\nConvert the ModelInput to a list of ints (tokens)\nThrows 
exception if there are any non-token chunks\n\n#### `length`\n\n```python\ndef length() -> int\n```\n\nReturn the total context length used by this ModelInput.\n\n#### `empty`\n\n```python\ndef empty(cls) -> \"ModelInput\"\n```\n\nCreate an empty ModelInput.\n\n#### `append`\n\n```python\ndef append(chunk: ModelInputChunk) -> \"ModelInput\"\n```\n\nAdd a new chunk, return a new ModelInput.\n\n#### `append_int`\n\n```python\ndef append_int(token: int) -> \"ModelInput\"\n```\n\nAdd a new token, return a new ModelInput.\n\n## `SessionEndEvent` Objects\n\n```python\nclass SessionEndEvent(BaseModel)\n```\n\n#### `duration`\n\nISO 8601 duration string\n\n#### `event`\n\nTelemetry event type\n\n#### `severity`\n\nLog severity level\n\n## `CreateSamplingSessionResponse` Objects\n\n```python\nclass CreateSamplingSessionResponse(BaseModel)\n```\n\n#### `sampling_session_id`\n\nThe generated sampling session ID\n\n## `CheckpointsListResponse` Objects\n\n```python\nclass CheckpointsListResponse(BaseModel)\n```\n\n#### `checkpoints`\n\nList of available model checkpoints for the model\n\n#### `cursor`\n\nPagination cursor information (None for unpaginated responses)\n\n## `SampleResponse` Objects\n\n```python\nclass SampleResponse(BaseModel)\n```\n\n#### `prompt_logprobs`\n\nIf prompt_logprobs was set to true in the request, logprobs are computed for\nevery token in the prompt. The `prompt_logprobs` response contains a float32\nvalue for every token in the prompt.\n\n#### `topk_prompt_logprobs`\n\nIf topk_prompt_logprobs was set to a positive integer k in the request,\nthe top-k logprobs are computed for every token in the prompt. The\n`topk_prompt_logprobs` response contains, for every token in the prompt,\na list of up to k (token_id, logprob) tuples.\n\n## `FutureRetrieveRequest` Objects\n\n```python\nclass FutureRetrieveRequest(StrictBase)\n```\n\n#### `request_id`\n\nThe ID of the request to retrieve\n\n#### `allow_metadata_only`\n\nWhen True, the server may return only response metadata (status and size)\ninstead of the full payload if the response exceeds the server's inline size limit.\n\n## `ForwardBackwardOutput` Objects\n\n```python\nclass ForwardBackwardOutput(BaseModel)\n```\n\n#### `loss_fn_output_type`\n\nThe class name of the loss function output records (e.g., 'TorchLossReturn', 'ArrayRecord').\n\n#### `loss_fn_outputs`\n\nDictionary mapping field names to tensor data\n\n#### `metrics`\n\nTraining metrics as key-value pairs.\n\nThe following metrics are recorded only during MoE (Mixture of Experts) training.\nNote: Don't fixate on the exact values of these metrics at the start of training.\nDifferent models on different data will have different initial values. How these\nmetrics evolve over training is what matters.\n\nIn the definitions below, *perfect balance* means ``total_tokens / num_experts``\n— the number of tokens each expert would receive if routing were perfectly uniform.\n\n- ``e_frac_with_tokens:mean``: Fraction of experts that received at least one token,\n  averaged across layers. A value of 1.0 means every expert got work; 0.5 means half\n  were idle. Decreasing over time is concerning (routing collapse).\n\n- ``e_frac_oversubscribed:mean``: Fraction of experts receiving more tokens than\n  perfect balance, averaged across layers. Increasing over time is concerning.\n\n- ``e_max_violation:mean``: How much the most overloaded expert exceeds perfect\n  balance, as a fraction of perfect balance, averaged across layers. 
Computed as\n  ``(max_tokens - perfect_balance) / perfect_balance``. A value of 2.0 means the\n  busiest expert got 3x the fair share. Increasing over time is concerning.\n\n- ``e_max_violation:max``: Same as ``e_max_violation:mean`` but takes the max\n  across layers instead of the mean. Shows the worst-case load imbalance in any\n  single layer.\n\n- ``e_min_violation:mean``: How much the least loaded expert is below perfect\n  balance, as a fraction of perfect balance, averaged across layers. Computed as\n  ``(min_tokens - perfect_balance) / perfect_balance``. A value of -0.5 means the\n  least-used expert got half the fair share; -1.0 means it got nothing. Typically\n  negative. Decreasing over time (more negative) is concerning.\n\n## `ModelData` Objects\n\n```python\nclass ModelData(BaseModel)\n```\n\nMetadata about a model's architecture and configuration.\n\n#### `arch`\n\nThe model architecture identifier.\n\n#### `model_name`\n\nThe human-readable model name.\n\n#### `tokenizer_id`\n\nThe identifier of the tokenizer used by this model.\n\n## `GetInfoResponse` Objects\n\n```python\nclass GetInfoResponse(BaseModel)\n```\n\nResponse containing information about a training client's model.\n\n#### `type`\n\nResponse type identifier.\n\n#### `model_data`\n\nDetailed metadata about the model.\n\n#### `model_id`\n\nUnique identifier for the model.\n\n#### `is_lora`\n\nWhether this is a LoRA fine-tuned model.\n\n#### `lora_rank`\n\nThe rank of the LoRA adaptation, if applicable.\n\n#### `model_name`\n\nThe name of the model.\n\n## `SaveWeightsResponse` Objects\n\n```python\nclass SaveWeightsResponse(BaseModel)\n```\n\n#### `path`\n\nA tinker URI for model weights at a specific step\n\n## `LoraConfig` Objects\n\n```python\nclass LoraConfig(StrictBase)\n```\n\n#### `rank`\n\nLoRA rank (dimension of low-rank matrices)\n\n#### `seed`\n\nSeed used for initialization of LoRA weights.\n\nUseful if you need deterministic or reproducible initialization of weights.\n\n#### `train_unembed`\n\nWhether to add lora to the unembedding layer\n\n#### `train_mlp`\n\nWhether to add loras to the MLP layers (including MoE layers)\n\n#### `train_attn`\n\nWhether to add loras to the attention layers\n\n## `SaveWeightsForSamplerResponseInternal` Objects\n\n```python\nclass SaveWeightsForSamplerResponseInternal(BaseModel)\n```\n\n#### `path`\n\nA tinker URI for model weights for sampling at a specific step\n\n#### `sampling_session_id`\n\nThe generated sampling session ID\n\n## `SaveWeightsForSamplerResponse` Objects\n\n```python\nclass SaveWeightsForSamplerResponse(BaseModel)\n```\n\n#### `path`\n\nA tinker URI for model weights for sampling at a specific step\n\n## `CreateSamplingSessionRequest` Objects\n\n```python\nclass CreateSamplingSessionRequest(StrictBase)\n```\n\n#### `session_id`\n\nThe session ID to create the sampling session within\n\n#### `sampling_session_seq_id`\n\nSequence ID for the sampling session within the session\n\n#### `base_model`\n\nOptional base model name to sample from.\n\nIs inferred from model_path, if provided. 
If sampling against a base model, this\nis required.\n\n#### `model_path`\n\nOptional tinker:// path to your model weights or LoRA weights.\n\nIf not provided, samples against the base model.\n\n## `OptimStepResponse` Objects\n\n```python\nclass OptimStepResponse(BaseModel)\n```\n\n#### `metrics`\n\nOptimization step metrics as key-value pairs\n\n## `SampleRequest` Objects\n\n```python\nclass SampleRequest(StrictBase)\n```\n\n#### `num_samples`\n\nNumber of samples to generate\n\n#### `base_model`\n\nOptional base model name to sample from.\n\nIs inferred from model_path, if provided. If sampling against a base model, this\nis required.\n\n#### `model_path`\n\nOptional tinker:// path to your model weights or LoRA weights.\n\nIf not provided, samples against the base model.\n\n#### `sampling_session_id`\n\nOptional sampling session ID to use instead of model_path/base_model.\n\nIf provided along with seq_id, the model configuration will be loaded from the\nsampling session. This is useful for multi-turn conversations.\n\n#### `seq_id`\n\nSequence ID within the sampling session.\n\nRequired when sampling_session_id is provided. Used to generate deterministic\nrequest IDs for the sampling request.\n\n#### `prompt_logprobs`\n\nIf set to `true`, computes and returns logprobs on the prompt tokens.\n\nDefaults to false.\n\n#### `topk_prompt_logprobs`\n\nIf set to a positive integer, returns the top-k logprobs for each prompt token.\n\n## `TrainingRun` Objects\n\n```python\nclass TrainingRun(BaseModel)\n```\n\n#### `training_run_id`\n\nThe unique identifier for the training run\n\n#### `base_model`\n\nThe base model name this model is derived from\n\n#### `model_owner`\n\nThe owner/creator of this model\n\n#### `is_lora`\n\nWhether this model uses LoRA (Low-Rank Adaptation)\n\n#### `corrupted`\n\nWhether the model is in a corrupted state\n\n#### `lora_rank`\n\nThe LoRA rank if this is a LoRA model, null otherwise\n\n#### `last_request_time`\n\nThe timestamp of the last request made to this model\n\n#### `last_checkpoint`\n\nThe most recent training checkpoint, if available\n\n#### `last_sampler_checkpoint`\n\nThe most recent sampler checkpoint, if available\n\n#### `user_metadata`\n\nOptional metadata about this training run, set by the end-user\n\n## `TelemetrySendRequest` Objects\n\n```python\nclass TelemetrySendRequest(StrictBase)\n```\n\n#### `platform`\n\nHost platform name\n\n#### `sdk_version`\n\nSDK version string\n\n## `CheckpointArchiveUrlResponse` Objects\n\n```python\nclass CheckpointArchiveUrlResponse(BaseModel)\n```\n\n#### `url`\n\nSigned URL to download the checkpoint archive\n\n#### `expires`\n\nUnix timestamp when the signed URL expires, if available\n\n## `SupportedModel` Objects\n\n```python\nclass SupportedModel(BaseModel)\n```\n\nInformation about a model supported by the server.\n\n#### `model_name`\n\nThe name of the supported model.\n\n## `GetServerCapabilitiesResponse` Objects\n\n```python\nclass GetServerCapabilitiesResponse(BaseModel)\n```\n\nResponse containing the server's supported models and capabilities.\n\n#### `supported_models`\n\nList of models available on the server.\n\n## `SessionStartEvent` Objects\n\n```python\nclass SessionStartEvent(BaseModel)\n```\n\n#### `event`\n\nTelemetry event type\n\n#### `severity`\n\nLog severity level\n\n## `GenericEvent` Objects\n\n```python\nclass GenericEvent(BaseModel)\n```\n\n#### `event`\n\nTelemetry event type\n\n#### `event_name`\n\nLow-cardinality event name\n\n#### `severity`\n\nLog severity level\n\n#### 
`event_data`\n\nArbitrary structured JSON payload\n\n## `TryAgainResponse` Objects\n\n```python\nclass TryAgainResponse(BaseModel)\n```\n\n#### `request_id`\n\nRequest ID that is still pending\n\n## `TrainingRunsResponse` Objects\n\n```python\nclass TrainingRunsResponse(BaseModel)\n```\n\n#### `training_runs`\n\nList of training runs\n\n#### `cursor`\n\nPagination cursor information\n\n## `ForwardBackwardInput` Objects\n\n```python\nclass ForwardBackwardInput(StrictBase)\n```\n\n#### `data`\n\nArray of input data for the forward/backward pass\n\n#### `loss_fn`\n\nFully qualified function path for the loss function\n\n#### `loss_fn_config`\n\nOptional configuration parameters for the loss function (e.g., PPO clip thresholds, DPO beta)\n\n## `ImageAssetPointerChunk` Objects\n\n```python\nclass ImageAssetPointerChunk(StrictBase)\n```\n\n#### `format`\n\nImage format\n\n#### `location`\n\nPath or URL to the image asset\n\n#### `expected_tokens`\n\nExpected number of tokens this image represents.\nThis is only advisory: the tinker backend will compute the number of tokens\nfrom the image, and we can fail requests quickly if the tokens does not\nmatch expected_tokens.\n\n## `TelemetryBatch` Objects\n\n```python\nclass TelemetryBatch(BaseModel)\n```\n\n#### `platform`\n\nHost platform name\n\n#### `sdk_version`\n\nSDK version string\n\n## `TensorData` Objects\n\n```python\nclass TensorData(StrictBase)\n```\n\n#### `data`\n\nFlattened tensor data as array of numbers.\n\n#### `shape`\n\nOptional.\n\nThe shape of the tensor (see PyTorch tensor.shape). The shape of a\none-dimensional list of length N is `(N,)`. Can usually be inferred if not\nprovided, and is generally inferred as a 1D tensor.\n\n#### `to_numpy`\n\n```python\ndef to_numpy() -> npt.NDArray[Any]\n```\n\nConvert TensorData to numpy array.\n\n#### `to_torch`\n\n```python\ndef to_torch() -> \"torch.Tensor\"\n```\n\nConvert TensorData to torch tensor.\n\n## `EncodedTextChunk` Objects\n\n```python\nclass EncodedTextChunk(StrictBase)\n```\n\n#### `tokens`\n\nArray of token IDs\n\n## `AdamParams` Objects\n\n```python\nclass AdamParams(StrictBase)\n```\n\n#### `learning_rate`\n\nLearning rate for the optimizer\n\n#### `beta1`\n\nCoefficient used for computing running averages of gradient\n\n#### `beta2`\n\nCoefficient used for computing running averages of gradient square\n\n#### `eps`\n\nTerm added to the denominator to improve numerical stability\n\n#### `weight_decay`\n\nWeight decay for the optimizer. Uses decoupled weight decay.\n\n#### `grad_clip_norm`\n\nMaximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 
0.0 means no clipping.\n\n## `ImageChunk` Objects\n\n```python\nclass ImageChunk(StrictBase)\n```\n\n#### `data`\n\nImage data as bytes\n\n#### `format`\n\nImage format\n\n#### `expected_tokens`\n\nExpected number of tokens this image represents.\nThis is only advisory: the tinker backend will compute the number of tokens\nfrom the image, and we can fail requests quickly if the tokens does not\nmatch expected_tokens.\n\n#### `validate_data`\n\n```python\ndef validate_data(cls, value: Union[bytes, str]) -> bytes\n```\n\nDeserialize base64 string to bytes if needed.\n\n#### `serialize_data`\n\n```python\ndef serialize_data(value: bytes) -> str\n```\n\nSerialize bytes to base64 string for JSON.\n\n## `SampledSequence` Objects\n\n```python\nclass SampledSequence(BaseModel)\n```\n\n#### `stop_reason`\n\nReason why sampling stopped\n\n#### `tokens`\n\nList of generated token IDs\n\n#### `logprobs`\n\nLog probabilities for each token (optional)\n\n## `Cursor` Objects\n\n```python\nclass Cursor(BaseModel)\n```\n\n#### `offset`\n\nThe offset used for pagination\n\n#### `limit`\n\nThe maximum number of items requested\n\n#### `total_count`\n\nThe total number of items available\n\n## `SaveWeightsRequest` Objects\n\n```python\nclass SaveWeightsRequest(StrictBase)\n```\n\n#### `path`\n\nA file/directory name for the weights\n\n#### `ttl_seconds`\n\nTTL in seconds for this checkpoint (None = never expires)\n"
  },
  {
    "path": "docs/async.mdx",
    "content": "# Async and Futures\n\n## Sync and Async APIs\n\nEvery method in the Tinker Python library has both a synchronous (sync) and an asynchronous (async) version. The async variants end with `_async`:\n\n| **Client** | **Sync method** | **Async method** |\n|---|---|---|\n| `ServiceClient` | `create_lora_training_client()` | `create_lora_training_client_async()` |\n| `TrainingClient` | `forward()` | `forward_async()` |\n| `SamplingClient` | `sample()` | `sample_async()` |\n| `RestClient` | `list_training_run_ids()` | `list_training_run_ids_async()` |\n\nTinker's `async` functionality requires an `asyncio` event loop, which you typically run like `asyncio.run(main())`.\n\n**When to use each:**\n\n- **Async:** Best for high-performance workflows where you need concurrency, especially when waiting on multiple network calls.\n- **Sync:** Simpler for scripts and learning examples. Easier to reason about but blocks on each operation.\n\nThe Tinker Cookbook generally uses `async` for implementations where performance is critical and sync for pedagogical examples.\n\n## Understanding Futures\n\nMost Tinker API methods are **non-blocking**, but may take a little while to run. They return immediately with a `Future` object that acknowledges that your request has been submitted. To get the actual result, you must explicitly wait:\n\n**Sync Python:**\n```python\nfuture = client.forward_backward(data, loss_fn)\nresult = future.result() # Blocks until complete\n```\n\n**Async Python (note the double await):**\n```python\nfuture = await client.forward_backward_async(data, loss_fn)\nresult = await future\n```\n\nAfter the first `await`, you're guaranteed that the request has been submitted, which ensures that it'll be ordered correctly relative to other requests. The second `await` waits for the actual computation to finish and returns the numerical outputs. For operations like `forward_backward`, the second `await` also guarantees that operation has been applied to the model---for `forward_backward`, this means that the gradients have been accumulated in the model's optimizer state.\n\n## Performance tips: overlap requests\n\nFor best performance, you should aim to submit your next request while the current one is running. Doing so is more important with Tinker than with other training systems because Tinker training runs on discrete [clock cycles](./under-the-hood#clock-cycles) (~10 seconds each). If you don't have a request queued when a cycle starts, you'll miss that cycle entirely.\n\n**Example pattern for overlapping forward_backward and optim_step:**\n```python\n# Submit forward_backward\nfwd_bwd_future = await client.forward_backward_async(batch, loss_fn)\n\n# Submit optim_step immediately (don't wait for forward_backward to finish)\noptim_future = await client.optim_step_async(adam_params)\n\n# Now retrieve results\nfwd_bwd_result = await fwd_bwd_future\noptim_result = await optim_future\n```\n\nThis pattern ensures both operations are queued and can be processed in the same [clock cycle](./under-the-hood#clock-cycles). In contrast, if you waited for `forward_backward` to complete before submitting `optim_step`, you might miss the next [clock cycle](./under-the-hood#clock-cycles).\n"
  },
  {
    "path": "docs/compatible-apis/openai.mdx",
    "content": "# OpenAI API Compatible Inference (in beta)\n\nOpenAI-compatible inference lets you interact with any model checkpoint in Tinker, using an endpoint compatible with the [OpenAI Completions API](https://platform.openai.com/docs/api-reference/chat). It’s designed to let you easily “poke at” your model while you're training it.\n\nFor inference within your training runs (e.g. RL), we recommend using Tinker’s standard [sampling client](/training-sampling).\n\nCurrently, OpenAI-compatible inference is meant for testing and internal use with low internal traffic, rather than large, high-throughput, user-facing deployments. Latency and throughput may vary by model and may change without notice during the beta. If you need higher or more stable throughput, contact the Tinker team in [our Discord](https://discord.gg/KqqEZNX88c) for guidance on larger-scale setups.\n\n## Use Cases\n\nOpenAI-compatible inference is designed for\n- **Fast feedback while training**: Start sampling very quickly from any sampler checkpoint obtained during training.\n- **Sampling while training continues**: Sample even while the training job is still running on that experiment.\n- **Developer &amp; internal workflows**: Intended for testing, evaluation, and internal tools.\n\nWe will release production-grade inference soon and will update our users then.\n\n## Using OpenAI compatible inference  from an OpenAI client\n\nThe new interface exposes an OpenAI-compatible HTTP API. You can use any OpenAI SDK or HTTP client that lets you override the base URL.\n\n1\\. Set the base URL of your OpenAI-compatible client to:\n\n```\nhttps://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1\n```\n\n2\\. Use a Tinker sampler weight path as the model name. For example:\n\n```\ntinker://0034d8c9-0a88-52a9-b2b7-bce7cb1e6fef:train:0/sampler_weights/000080\n```\n\nAny valid Tinker sampler checkpoint path works here. You can keep training and sample from the same checkpoint simultaneously.\n\n3\\. Authenticate with your Tinker API key, by passing the same key used for Tinker as the API key to the OpenAI client.\n\n**Note:** We support both `/completions` and `/chat/completions` endpoints. 
Chat requests are rendered with the model’s default Hugging Face chat template; if your checkpoint expects a different renderer, render the prompt yourself (see [Rendering](/rendering)) and use `/completions`.\n\n## Code Example\n\n```py\nfrom os import getenv\nfrom openai import OpenAI\n\nBASE_URL = \"https://tinker.thinkingmachines.dev/services/tinker-prod/oai/api/v1\"\nMODEL_PATH = \"tinker://0034d8c9-0a88-52a9-b2b7-bce7cb1e6fef:train:0/sampler_weights/000080\"\n\napi_key = getenv(\"TINKER_API_KEY\")\n\nclient = OpenAI(\n    base_url=BASE_URL,\n    api_key=api_key,\n)\n\nresponse = client.completions.create(\n    model=MODEL_PATH,\n    prompt=\"The capital of France is\",\n    max_tokens=50,\n    temperature=0.7,\n    top_p=0.9,\n)\n\nprint(f\"{response.choices[0].text}\")\n```\n\nNotes:\n\n* `BASE_URL` points to the OpenAI compatible inference endpoint.\n* `MODEL_PATH` is a sampler checkpoint path from Tinker (`tinker://0034d8c9-0a88-52a9-b2b7-bce7cb1e6fef:train:0/sampler_weights/000080`).\n* The rest of the arguments (`prompt`, `max_tokens`, `temperature`, `top_p`) behave like they do in the OpenAI Completions API.\n* You can swap `MODEL_PATH` to any other sampler checkpoint to compare runs quickly in your evals or notebooks.\n\n## Related docs\n\n* [Getting a `TINKER_API_KEY`](/install)\n\n* [Security and Privacy](https://thinkingmachines.ai/legal/terms/)\n\n* [Training and Sampling](/training-sampling)\n"
  },
  {
    "path": "docs/completers.mdx",
    "content": "import { CookbookLink } from '../components/CookbookLink'\n\n# Completers\n\nThe concept of policies is crucial to the RL training process. In the Tinker Cookbook, policies are implemented as `Completers`. Completers are abstractions that represent models or policies that can be sampled from, providing different levels of structure depending on your use case.\n\n## Overview of Completer Types\n\nThe Tinker Cookbook provides two main types of completers, each designed for different use cases:\n\n1. **TokenCompleter**: Operates on tokens and is used by RL algorithms\n2. **MessageCompleter**: Operates on messages and needs to be used with a renderer\n\nThe choice between these depends on whether you're working at the token level for RL training or at the message level for interacting with and evaluating the model.\n\n### TokenCompleter\n\nThe `TokenCompleter` is the foundational interface used by RL algorithms because they work directly with tokens.\n\n```python\nclass TokenCompleter:\n    async def __call__(\n        self, model_input: types.ModelInput, stop: StopCondition\n    ) -> TokensWithLogprobs:\n```\n\nThis interface takes:\n- `model_input`: The input to the model (of type `types.ModelInput`)\n- `stop`: Stop conditions, either a list of strings or token IDs (combined into a `StopCondition` class). When training with reinforcement learning, this should be defined by the `initial_observation` function of the environment.\n\nIt returns a `TokensWithLogprobs` object containing:\n- `tokens`: The generated token sequence\n- `maybe_logprobs`: Optional log probabilities for each token\n\n### MessageCompleter\n\nThe `MessageCompleter` operates at a higher level with structured messages, similarly to standard chat APIs. It takes a list of messages and returns a single assistant message response.\n\n```python\nclass MessageCompleter:\n    async def __call__(self, messages: list[renderers.Message]) -> renderers.Message:\n```\n\nFor training purposes the `TokenCompleter` is the class we will use for RL training as we need to optimize the same same set of tokens during the update step that the model output during rollout. The `MessageCompleter` is useful for sampling where we need to use the model output for semantic purposes such as Judge models or multi-agent environments.\n\nThe Tinker Cookbook uses two concrete implementations of these interfaces - <CookbookLink path=\"tinker_cookbook/completers.py\">`TinkerTokenCompleter`</CookbookLink> and <CookbookLink path=\"tinker_cookbook/completers.py\">`TinkerMessageCompleter`</CookbookLink> which are both wrappers around a `tinker.SamplingClient`. While the TinkerTokenCompleter operates directly on tokens, the TinkerMessageCompleter needs to be instantiated with a renderer to make it compatible with the inputs expected by the sampling client.\n"
  },
  {
    "path": "docs/dev-tips.mdx",
    "content": "# Developer Tips\n\n## AI-assisted development\n\nThis documentation is mirrored in the [docs folder of Tinker Cookbook](https://github.com/thinking-machines-lab/tinker-cookbook/tree/main/docs). The root of Tinker Cookbook also contains an `AGENTS.md` / `CLAUDE.md` file with a guide to the code and documentation. We recommend making a git clone of the cookbook and pointing your coding agent at the repository root.\n"
  },
  {
    "path": "docs/docs-outline.mdx",
    "content": "# Navigating these docs\n\nThese docs provide guides to both Tinker and the Tinker Cookbook.\n\nThe first half, \"Using the Tinker API\", walks you through the fundamentals of Tinker:\n\n- [Installation](./install) explains how to install both `tinker` and `tinker-cookbook`, and points you to the Tinker Console for your API key.\n- [Training and Sampling](./training-sampling) takes you through your first training run: setting up your training data, performing the run, and sampling from the model to test the run.\n- [Loss Functions](./losses) starts to get into the detail. Tinker supports a variety of built-in loss functions, but also allows you to use arbitrary differentiable loss functions.\n- [Saving and Loading](./save-load) explains the checkpoint types available in Tinker, and how to restart a run from a checkpoint.\n- [Async and Futures](./async) explains Tinker's `sync` and `async` API variants, and how Futures works as Tinker's requests structure.\n- [Model Lineup](./model-lineup) is regularly updated with the models available to fine-tune in Tinker.\n\nThe second half, \"The Tinker Cookbook\", provides recipes for how to use the Tinker API for research and applications. You are welcome to adapt these directly for your own use cases.\n\n- [Rendering](./rendering) explains how we convert from a conversation data structure to a list of tokens.\n- [Supervised Learning](./supervised-learning) explains basic SL and walks you through your first SL training loop. We make some suggestions for hyperparameter selection and detail how you can run your own hyperparameter sweep. We also show you how to perform prompt distillation.\n- [Reinforcement Learning](./rl) explains the basics of RL and walks you through your first RL run. We explain and provide code for creating your own RL environments and training on them. We provide a simple training loop for you to use and adapt, and explain RL hyperparameters and loss functions in detail.\n- [Preferences](./preferences) is a guide to learning from pairwise feedback, where  we have preference data indicating which of two completions is better for a given prompt. We walk you through two approaches to learning from pairwise preference data: direct preference optimization (DPO) and reinforcement learning from human feedback (RLHF).\n- [Evaluations](./evals) explains how you can use Tinker's outputs to run inline and offline evals on your runs.\n- [Completers](./completers) explains how Tinker implements policies, and provides two examples of how to use these in training.\n- [LoRA Primer](./lora-primer) explains the basic background of LoRA, and how to choose hyperparameters.\n"
  },
  {
    "path": "docs/download-weights.mdx",
    "content": "# Downloading weights\n\n### CLI\n\n```bash\ntinker checkpoint download $TINKER_CHECKPOINT_PATH\n```\n\nSee `tinker checkpoint download --help` for more details.\n\n### SDK\n\nYou can also download checkpoints using the SDK.\n\nExample:\n\n```python\nimport tinker\nimport urllib.request\n\nsc = tinker.ServiceClient()\nrc = sc.create_rest_client()\nfuture = rc.get_checkpoint_archive_url_from_tinker_path(\"tinker://<unique_id>/sampler_weights/final\")\ncheckpoint_archive_url_response = future.result()\n\n# `checkpoint_archive_url_response.url` is a signed URL that can be downloaded\n# until checkpoint_archive_url_response.expires\nurllib.request.urlretrieve(checkpoint_archive_url_response.url, \"archive.tar\")\n```\n\nReplace `<unique_id>` with your Training Run ID. This will save the LoRA adapter weights and config inside the `archive.tar` file.\n"
  },
  {
    "path": "docs/evals.mdx",
    "content": "import { Callout } from 'nextra/components'\nimport { CookbookLink } from '../components/CookbookLink'\n\n# Evaluations\n\nOur training scripts will print out training and test loss. Two common workflows for evaluations are to do inline evals during training and to do offline evals on various checkpoints from a run.\n\n## Inline Evals\n\nYou can add inline evaluations to your training runs by configuring evaluator builders in advance for both supervised fine-tuning and RL training jobs.\n\n### Supervised Fine-Tuning (`supervised.train`)\nAdd one or both of the following to your config:\n\n- **`evaluator_builders: list[EvaluatorBuilder]`** - Runs evaluations every `eval_every` steps\n- **`infrequent_evaluator_builders: list[EvaluatorBuilder]`** - Runs evaluations every `infrequent_eval_every` steps\n\n### RL Training (`rl.train`)\n\nAdd the following to your config:\n\n- **`evaluator_builders: list[SamplingClientEvaluator]`** - Runs evaluations every `eval_every` steps\n\nFor implementation guidance and a detailed example, see <CookbookLink path=\"tinker_cookbook/eval/evaluators.py\">here</CookbookLink> and\n <CookbookLink path=\"tinker_cookbook/eval/inspect_evaluators.py\">here</CookbookLink> respectively.\n\n\n## Offline evals\n\nWe support and recommend several ways for creating and running your offline evaluations on your model checkpoints.\n\n### Running Standard Evaluations with Inspect AI.\n\nWe support running many of the standard cited evaluations using the [Inspect AI library](https://github.com/UKGovernmentBEIS/inspect_ai).\n\nWe have provided a <CookbookLink path=\"tinker_cookbook/eval/run_inspect_evals.py\">script</CookbookLink> to evaluate models using Tinker's internal sampling functionality as shown below.\n\n```bash\nMODEL_PATH=tinker://FIXME # YOUR MODEL PATH HERE\npython -m tinker_cookbook.eval.run_inspect_evals \\\n    model_path=$MODEL_PATH \\\n    model_name=MODEL_NAME \\ # YOUR MODEL_NAME HERE\n    tasks=inspect_evals/ifeval,inspect_evals/mmlu_0_shot \\\n    renderer_name=RENDERER_NAME # YOUR RENDERER_NAME HERE\n```\n\nClick [here](https://github.com/UKGovernmentBEIS/inspect_ai/blob/main/docs/evals/listing.yml) to view additional supported evaluations.\n\n### Creating your own Sampling Evaluations\n\nWe recommend two ways to create your own evaluations:\n- creating your own tasks with Inspect AI and running like above\n- creating your own SamplingClientEvaluator\n\n#### Create tasks with Inspect AI\n\nIn addition to passing in standard evaluations, you can create your own tasks using inspect ai as detailed [here](https://inspect.aisi.org.uk/tasks.html).\n\nHere is a toy example of how to create an evaluation with an LLM-as-a-judge where we use a model produced by tinker as a grader.\n\n```python\nimport tinker\nfrom inspect_ai import Task, task\nfrom inspect_ai.dataset import MemoryDataset, Sample\nfrom inspect_ai.model import GenerateConfig as InspectAIGenerateConfig\nfrom inspect_ai.model import Model as InspectAIModel\nfrom inspect_ai.scorer import model_graded_qa\nfrom inspect_ai.solver import generate\nfrom tinker_cookbook.eval.inspect_utils import InspectAPIFromTinkerSampling\n\nQA_DATASET = MemoryDataset(\n    name=\"qa_dataset\",\n    samples=[\n        Sample(\n            input=\"What is the capital of France?\",\n            target=\"Paris\",\n        ),\n        Sample(\n            input=\"What is the capital of Italy?\",\n            target=\"Rome\",\n        ),\n    ],\n)\n\nservice_client = tinker.ServiceClient()\nsampling_client = 
service_client.create_sampling_client(\n    base_model=\"meta-llama/Llama-3.1-8B-Instruct\"\n)\n\napi = InspectAPIFromTinkerSampling(\n    renderer_name=\"llama3\",\n    model_name=\"meta-llama/Llama-3.1-8B-Instruct\",\n    sampling_client=sampling_client,\n    verbose=False,\n)\n\nGRADER_MODEL = InspectAIModel(api=api, config=InspectAIGenerateConfig())\n\n\n@task\ndef example_lm_as_judge() -> Task:\n    \"\"\"\n    Example task using LLM-as-a-judge scoring.\n\n    Note: The grader model defaults to the model being evaluated.\n    To use a different grader model, specify it with --model-grader when using inspect directly.\n    \"\"\"\n    return Task(\n        name=\"llm_as_judge\",\n        dataset=QA_DATASET,\n        solver=generate(),\n        scorer=model_graded_qa(\n            instructions=\"Grade strictly against the target text as general answer key and rubric. \"\n            \"Respond 'GRADE: C' if correct or 'GRADE: I' otherwise.\",\n            partial_credit=False,\n            # model parameter is optional - if not specified, uses the model being evaluated\n            model=GRADER_MODEL,\n        ),\n    )\n```\n\nInspect also natively supports replacing our `GRADER_MODEL` with any OpenAI chat-completions-style API (e.g., OpenRouter).\n\n#### Create your own SamplingClientEvaluator\n\nAlternatively, you can create your own SamplingClientEvaluator class instead of using Inspect AI. This is a lower-level abstraction than the above, with finer-grained control over running your evaluations.\n\nWe expose this interface to allow users more control over their datasets and metrics. To illustrate, see this\n<CookbookLink path=\"tinker_cookbook/eval/custom_evaluators.py\">custom evaluators</CookbookLink> example of how one might create their own complex SamplingClientEvaluator.\n\nFor a more instructive toy example, see below.\n\n```python\nfrom typing import Any, Callable\n\nimport tinker\nfrom tinker import types\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.evaluators import SamplingClientEvaluator\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nclass CustomEvaluator(SamplingClientEvaluator):\n    \"\"\"\n    A toy SamplingClientEvaluator that runs a custom evaluation and returns its metrics.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: Any,\n        grader_fn: Callable[[str, str], bool],\n        model_name: str,\n        renderer_name: str,\n    ):\n        \"\"\"\n        Initialize the CustomEvaluator.\n\n        Args:\n            dataset: Iterable of examples with \"input\" and \"output\" fields\n            grader_fn: Function taking (response_text, target_text) and returning True if the response is correct\n            model_name: Model name used to load the tokenizer\n            renderer_name: Renderer used to build prompts and parse responses\n        \"\"\"\n        self.dataset = dataset\n        self.grader_fn = grader_fn\n\n        tokenizer = get_tokenizer(model_name)\n        self.renderer = renderers.get_renderer(name=renderer_name, tokenizer=tokenizer)\n\n    async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:\n        \"\"\"\n        Run custom evaluation on the given sampling client and return metrics.\n        Args:\n            sampling_client: The sampling client to evaluate\n        Returns:\n            Dictionary of metrics from the evaluation\n        \"\"\"\n\n        metrics = {}\n\n        num_examples = len(self.dataset)\n        num_correct = 0\n\n        sampling_params = types.SamplingParams(\n            max_tokens=100,\n            temperature=0.7,\n            top_p=1.0,\n            stop=self.renderer.get_stop_sequences(),\n        )\n\n        for datum in self.dataset:\n            model_input: 
types.ModelInput = self.renderer.build_generation_prompt(\n                [renderers.Message(role=\"user\", content=datum[\"input\"])]\n            )\n            # Generate response\n            r: types.SampleResponse = await sampling_client.sample_async(\n                prompt=model_input, num_samples=1, sampling_params=sampling_params\n            )\n            tokens: list[int] = r.sequences[0].tokens\n            response: renderers.Message = self.renderer.parse_response(tokens)[0]\n            if self.grader_fn(response[\"content\"], datum[\"output\"]):\n                num_correct += 1\n\n        metrics[\"accuracy\"] = num_correct / num_examples\n        return metrics\n```\n\nHere is an example of how we can use the above CustomEvaluator on a toy dataset and grader.\n\n```python\nimport asyncio\n\nimport tinker\n\nQA_DATASET = [\n    {\"input\": \"What is the capital of France?\", \"output\": \"Paris\"},\n    {\"input\": \"What is the capital of Germany?\", \"output\": \"Berlin\"},\n    {\"input\": \"What is the capital of Italy?\", \"output\": \"Rome\"},\n]\n\ndef grader_fn(response: str, target: str) -> bool:\n    return target.lower() in response.lower()\n\nevaluator = CustomEvaluator(\n    dataset=QA_DATASET,\n    grader_fn=grader_fn,\n    renderer_name=\"llama3\",\n    model_name=\"meta-llama/Llama-3.1-8B-Instruct\",\n)\n\nservice_client = tinker.ServiceClient()\nsampling_client = service_client.create_sampling_client(base_model=\"meta-llama/Llama-3.1-8B-Instruct\")\n\nasync def main():\n    result = await evaluator(sampling_client)\n    print(result)\n\nasyncio.run(main())\n```\n
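\nSince offline evals typically compare several checkpoints from the same run, you can reuse the same evaluator across multiple sampler checkpoints. Here is a small sketch building on the snippet above; the `tinker://` paths are placeholders for your own run's sampler checkpoints:\n\n```python\n# Placeholder sampler checkpoint paths from one training run\nCHECKPOINT_PATHS = [\n    \"tinker://YOUR-RUN-ID/sampler_weights/000050\",\n    \"tinker://YOUR-RUN-ID/sampler_weights/000100\",\n]\n\nasync def compare_checkpoints():\n    for path in CHECKPOINT_PATHS:\n        # Point a sampling client at each checkpoint and score it with the evaluator above\n        client = service_client.create_sampling_client(model_path=path)\n        print(path, await evaluator(client))\n\nasyncio.run(compare_checkpoints())\n```\n"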
  },
  {
    "path": "docs/index.mdx",
    "content": "# Tinker: a training API for researchers and developers\n\nTinker lets you focus on what matters in LLM fine-tuning – your data and algorithms – while we handle the heavy lifting of distributed training.\n\nYou write a simple loop that runs on your CPU-only machine, including the data or environment and the loss function. We figure out how to make the training work on a bunch of GPUs, doing the exact computation you specified, efficiently. To change the model you're working with, you only need to change a single string in your code.\n\nTinker gives you full control over the training loop and all the algorithmic details. It's not a magic black box that makes fine-tuning \"easy\". It's a clean abstraction that shields you from the complexity of distributed training while preserving your control.\n\nHere's how the division of responsibilities works in practice:\n\n| **You focus on** | **You write** | **We handle** |\n|---|---|---|\n|  **Datasets and RL environments**<br />Your custom training data |  **Simple Python script**<br />Runs on your CPU |  **Efficient distributed training of large models**<br />Llama 70B, Qwen 235B |\n|  **Training logic**<br />Your loss functions, training loop, and evals |  **API calls**<br />`forward_backward()`<br />`optim_step()`<br />`sample()`<br />`save_state()` |  **Reliability**<br />Hardware failures handled transparently |\n\n## Features\n\nWhat the Tinker service currently supports:\n\n- Tinker lets you fine-tune open-weight models like the Qwen and Llama series, including large mixture-of-experts models like Qwen3-235B-A22B.\n- Tinker supports vision-language models (VLMs) like Qwen3-VL for image understanding tasks. See [Vision Inputs](/rendering#vision-inputs) for details.\n- Tinker implements low-rank adaptation (LoRA) fine-tuning, not full fine-tuning. However, we believe that LoRA gives the same performance as full fine-tuning for many important use cases, especially in RL (see [LoRA Without Regret](https://thinkingmachines.ai/blog/lora/)).\n- You can download the weights of your trained model to use outside of Tinker, for example with your inference provider of choice.\n\n## A quick look at functionality\n\nTinker's main functionality is contained in a few key functions:\n\n- `forward_backward`: feed in your data and loss function, and we'll compute and accumulate the gradients for you.\n- `optim_step`: update your model using the accumulated gradients\n- `sample`: Generate outputs from your trained model\n- other functions for saving and loading weights and optimizer state\n\n## What's next?\n\nSome features we expect to support in the future:\n\n- Full fine-tuning\n"
  },
  {
    "path": "docs/install.mdx",
    "content": "# Installing Tinker\n\nInstall the Tinker SDK with:\n\n```bash\npip install tinker\n```\n\nInstallation makes two components available: the python SDK and the tinker CLI.\n\n#### Python SDK\n\nThe python SDK provides low-level operations like `forward_backward`, `sample`, `optim_step`, and `save_state`.\n\n#### Tinker CLI\n\nThe tinker CLI is available as `tinker` or through `python -m tinker`. The CLI provides management functionality similar to that of the web console.\n\nRun `tinker --help` to see which functionality is available.\n\n## Tinker Cookbook\n\nWe also release [tinker-cookbook](https://github.com/thinking-machines-lab/tinker-cookbook), which is a collection of training code and experiment tools built on top of Tinker.\nFor the Cookbook, we'd recommend doing a local editable install, as you'll probably want to browse and edit the code:\n\n```bash\ngit clone https://github.com/thinking-machines-lab/tinker-cookbook.git\ncd tinker-cookbook\n# Switch to your virtual environment\npip install -e .\n```\n\n## Getting an API key\n\nCreate an API key from the [console](https://tinker-console.thinkingmachines.ai). You'll then want to set the `TINKER_API_KEY` environment variable to your newly generated API key.\n"
  },
  {
    "path": "docs/lora-primer.mdx",
    "content": "# LoRA Primer\n\nTinker supports [LoRA fine-tuning](https://arxiv.org/abs/2106.09685), which adjusts a small number of parameters, rather than full fine-tuning, which adjusts all of the parameters of the original model.\n\nOur current understanding is that LoRA has equivalent performance to full fine-tuning when doing RL or doing SL on small datasets, while it has worse performance on larger datasets. In more detail:\n\n- For supervised fine-tuning on small-to-medium-sized instruction-tuning and reasoning datasets, LoRA performs the same as full fine-tuning.\n- For datasets that exceed LoRA capacity, LoRA underperforms FullFT. Rather than the loss reaching a distinct floor that it can’t go below, LoRA results in worse training efficiency that depends on the relationship between model capacity to dataset size.\n- In some scenarios, LoRA is less tolerant of large batch sizes than full fine-tuning — it pays a larger penalty in loss as batch size increases beyond some point. This penalty is not mitigated by increasing the LoRA rank; it is a property of the product-of-matrices parametrization, which has different training dynamics than optimizing the original weight matrix.\n- Even in small data settings, LoRA performs better when applied to all weight matrices, especially MLP and MoE layers. Attention-only LoRA underperforms even when we match the number of trainable parameters by using higher rank for attention-only LoRA.\n- LoRA performs equivalently to FullFT for reinforcement learning even with small ranks. We find that RL requires very low capacity, a result we anticipated based on information-theoretical arguments.\n\nSee [LoRA Without Regret](https://thinkingmachines.ai/blog/lora) for more details and experimental results.\n\n## Hyperparameters\n\nThe learning rate (LR) is usually the most important hyperparameter in your ML experiments.\n\n\nLoRA requires a much larger LR than full fine-tuning---typically 20-100x larger, depending on model size. People often mistakenly retain their full fine-tuning LR when they port their code to use LoRA, leading them to conclude that LoRA works poorly.\n\n**Calculate the correct LoRA learning rate:**\n\nWe've provided a utility that calculates the factor you should scale the full fine-tuning LR by to get the equivalent LoRA LR:\n\n```python\nfrom tinker_cookbook.hyperparam_utils import get_lora_lr_over_full_finetune_lr\n\nmodel_name = \"meta-llama/Llama-3.1-8B\"\nprint(get_lora_lr_over_full_finetune_lr(model_name))\n```\n\nNote that for `Llama-3.2-1B`, the factor is 32, while for `Llama-3.1-70B`, the factor is 128.\n\n## What is LoRA exactly?\n\nLoRA is short for Low-Rank Adaptation. Given that the original model has a weight matrix $W$, we replace it with a new weight matrix $W'=W + BA$, where $B$ and $A$ are low-rank matrices. If $W$ is an $n \\times n$ matrix, then $B$ and $A$ are $n \\times r$ and $r \\times n$ matrices, respectively, where $r$ is the rank of the low-rank approximation. The default $r$ used by tinker is $32$.\n\nThe fact that LoRA uses a low-rank approximation of weight matrices is not terribly important. We prefer to think of LoRA as just a random projection of the parameter space that happens to be efficient to implement. When training with RL or small SL datasets, we are only learning a small amount of information, and this reduced set of parameters is more than enough.\n\n\n## What rank to use?\n\nThe default rank used by tinker is $32$. 
However, if you're doing SL on a large dataset, you should use a larger rank. For supervised learning, as a very rough approximation, LoRA will give good results as long as the number of LoRA parameters is at least as large as the number of completion tokens (i.e., weight=1 tokens). You can calculate the number of LoRA parameters with the following utility:\n\n```python\nfrom tinker_cookbook.hyperparam_utils import get_lora_param_count\n\nmodel_name = \"meta-llama/Llama-3.1-8B\"\nprint(get_lora_param_count(model_name, lora_rank=32))\n```\n\nFor reinforcement learning, we've found that small ranks give equivalent performance to larger ranks and full fine-tuning.\n\nNote that conveniently, the optimal learning rate does *not* depend on the LoRA rank. In fact, you can verify that if you train with SL on different ranks (but with the same LR), you'll get exactly the same learning curves for the first few steps of training.\n"
  },
  {
    "path": "docs/losses.mdx",
    "content": "import { CookbookLink } from '../components/CookbookLink'\n\n# Loss functions in Tinker\n\nFor most use cases, you can use the Tinker API's built-in loss functions by passing in a string identifier to `forward_backward`, which supports cross-entropy and policy gradient objectives. When you need more control, `forward_backward_custom` enables arbitrary differentiable loss functions at the cost of an additional forward pass; we explain both approaches in this doc.\n\nWhen you call `forward_backward`, you specify a loss function using a string that selects from a predetermined set of options, comprising the most common losses used for language model training.\n- **Input:** `forward_backward` expects a certain set of input tensors, passed in via `datum.loss_fn_inputs`, which is a dict mapping `str` to either a numpy or torch tensor\n- **Output:** `forward_backward` returns a `ForwardBackwardOutput`, which has a set of output tensors in `fwd_bwd_result.loss_fn_outputs`\n\nFor an example of using `forward_backward`, see `rl/train.py` in the Cookbook:\n```python\nimport tinker\nimport torch\nfrom tinker import TensorData\n\n# Create training data with required inputs\ndatum = tinker.Datum(\n    model_input=input_tokens,\n    loss_fn_inputs={\n        \"target_tokens\": TensorData.from_torch(torch.tensor(target_tokens)),\n        \"logprobs\": TensorData.from_torch(torch.tensor(sampling_logprobs)),  # Reference logprobs\n        \"advantages\": TensorData.from_torch(torch.tensor(advantages)),\n    }\n)\n\n# Option 1: Use importance sampling REINFORCE\nfwd_bwd_result = await training_client.forward_backward_async(\n    [datum], loss_fn=\"importance_sampling\"\n)\n\n# Option 2: Use PPO with clipping\nfwd_bwd_result = await training_client.forward_backward_async(\n    [datum], loss_fn=\"ppo\"\n)\n```\n\n## Basic loss functions\n\nCurrently, the Tinker API supports `cross_entropy` (for supervised learning), `importance_sampling`, `ppo`, `cispo` and `dro` for RL. We denote the training model as $p_{\\theta}$, the sampling distribution as $q$, and advantages as $A$. Also, for notation simplicity we omit the query and denote the full model completion sequence of tokens as $x$.\n\nAll losses are applied at the token level and tensors below have shape `(N,)` where `N` is `model_input.length`. 
They can be provided as `numpy.ndarray` or `torch.Tensor`, and the return values will use the same tensor type.\n\n### Supervised learning: `cross_entropy`\n\nFor SL, we implement the standard cross-entropy loss (i.e., negative log-likelihood), which optimizes the policy $p_\\theta$ to maximize the log-probability of the tokens $x$:\n\n$$\n\\mathcal{L(\\theta)} = -\\mathbb{E}_x[\\log p_\\theta(x)]\n$$\n\nwhere `weights` is either 0 or 1, typically generated from `renderer.build_supervised_example()` which returns `(model_input, weights)` (i.e., to specify the desired assistant turns to train on).\n\nThis is implemented as:\n\n```python\n# Apply weights and compute elementwise loss\nelementwise_loss = -target_logprobs * weights\n# Apply sum reduction to get the total loss\nloss = elementwise_loss.sum()  # scalar\n```\n\n- **Input tensors:**\n  - `target_tokens: array[(N,), int]` - Target token IDs\n  - `weights: array[(N,), float]` - Token-level loss weights (typically from the renderer)\n- **Output tensors:**\n  - `logprobs: array[(N,), float]` - Log probabilities of predicted tokens\n- **Output diagnostics:**\n  - `loss:sum` (scalar) - Sum of weighted cross-entropy losses\n\n### Policy gradient: `importance_sampling`\n\nFor RL, we implement a common variant of the policy gradient objective, used in practical settings where the *learner policy* $p$ may differ from the *sampling policy* $q$, which is common due to, e.g., [non-determinism](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/). The issue is that if these policies differ, then the objective:\n\n$$\n\\mathcal{L}(\\theta) = \\mathbb{E}_{x\\sim p_\\theta}\\bigl[A(x)\\bigr]\n$$\n\nis not computed in an unbiased why due to $x \\sim q$ (sampler) not exactly matching the desired $x \\sim p_\\theta$ (learner). To correct the bias, we use a modified \"importance sampling\" objective:\n\n$$\n\\mathcal{L}_{\\text{IS}}(\\theta) = \\mathbb{E}_{x\\sim q}\\Bigl[\\frac{p_\\theta(x)}{q(x)}A(x)\\Bigr],\n$$\n\nwhich yields the correct expected reward. In the formula above:\n\n- $\\log p_\\theta(x)$ – `target_logprobs` is from the learner, on the forward part of the `forward_backward` pass.\n- $\\log q(x)$ – `sampling_logprobs` is from the sampler, recorded during sampling as a correction term.\n\nThis is implemented as:\n\n```python\n# Compute probability ratio\nprob_ratio = torch.exp(target_logprobs - sampling_logprobs)\n# Compute importance-weighted loss\nloss = -(prob_ratio * advantages).sum()\n```\n\n- **Input tensors:**\n  - `target_tokens: array[(N,), int]` - Target token IDs (from the sampler $q$)\n  - `logprobs: array[(N,), float]` - `sampling_logprobs` for the tokens\n  - `advantages: array[(N,), float]` - Advantage values for RL (positive to reinforce, negative to discourage)\n- **Output tensors:**\n  - `logprobs: array[(N,), float]` - `target_logprobs` for the tokens\n- **Output diagnostics:**\n  - `loss:sum` (scalar) - Sum of importance-weighted policy gradient losses $\\mathcal L_{\\text{IS}}$\n\n### Proximal Policy Optimization: `ppo`\n\nPPO ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)) addresses issues with standard policy gradient methods by introducing a clipping objective that limits policy updates within a close neighborhood of the sampling distribution. 
This prevents updates that are too large in policy space, especially when taking multiple gradient steps on the same rollout distribution.\n\nThe objective clips the importance ratio $\\frac{p_\\theta(x)}{q(x)}$ to prevent large policy updates, where $p_\\theta$ is the learner policy and $q$ is the sampling policy. Note that the PPO clipping and loss computation is applied token-wise, computing the loss for each token independently.\n\nThe PPO clipping objective is:\n\n$$\n\\mathcal{L}_{\\text{CLIP}}(\\theta) = -\\mathbb{E}_{x \\sim q}\\left[\\text{clip}\\left(\\frac{p_\\theta(x)}{q(x)}, 1-\\epsilon_{\\text{low}}, 1+\\epsilon_{\\text{high}}\\right) \\cdot A(x)\\right]\n$$\n\nThe final PPO loss combines the clipped and unclipped objectives:\n\n$$\n\\mathcal{L}_{\\text{PPO}}(\\theta) = -\\mathbb{E}_{x \\sim q}\\left[\\min\\left(\\frac{p_\\theta(x)}{q(x)} \\cdot A(x), \\text{clip}\\left(\\frac{p_\\theta(x)}{q(x)}, 1-\\epsilon_{\\text{low}}, 1+\\epsilon_{\\text{high}}\\right) \\cdot A(x)\\right)\\right]\n$$\n\nwhere $\\epsilon_{\\text{low}}$ and $\\epsilon_{\\text{high}}$ are hyperparameters (currently fixed to 0.2 in Tinker).\n\nThis is implemented as:\n\n```python\n# Compute probability ratio\nprob_ratio = torch.exp(target_logprobs - sampling_logprobs)\n# Apply clipping\nclipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)\n# Compute both objectives\nunclipped_objective = prob_ratio * advantages\nclipped_objective = clipped_ratio * advantages\n# Take minimum (most conservative)\nppo_objective = torch.min(unclipped_objective, clipped_objective)\n# PPO loss is negative of objective\nloss = -ppo_objective.sum()\n```\n\n\n**Example with custom clipping thresholds:**\n```python\nfwd_bwd_result = await training_client.forward_backward_async(\n    data=data,\n    loss_fn=\"ppo\",\n    loss_fn_config={\"clip_low_threshold\": 0.9, \"clip_high_threshold\": 1.1}\n)\n```\n\n**Additional Notes:**\n- The loss formulation above is quite general, since the user can organize the data generation and advantage estimation in their own code. For example, the main RL training scripts in the Tinker Cookbook use group-based rollouts with per-group advantage centering similar to GRPO ([Shao et al., 2024](https://arxiv.org/abs/2402.03300)).\n- The functional implementations of REINFORCE and PPO do not use an additional KL term like the original GRPO work, which has been noted to be mathematically inconsistent ([Zhang et al., 2025](https://arxiv.org/abs/2505.17508); [Tang et al., 2025](https://arxiv.org/abs/2506.09477)). However, it is possible to include a KL regularization term as part of the reward, which is mathematically correct and we provide this option in our RL training <CookbookLink path=\"tinker_cookbook/rl/train.py\">code and examples</CookbookLink>  (consider the incorporate_kl_penalty function).\n- Notice that for all objectives we sum the token-level losses over the sequence length unlike some other loss implementations. If you would like to explore different aggregation schemes, you can include that in the advantage tensor computation.\n\n### Clipped Importance Sampling Policy Optimization: `cispo`\n\nCISPO ([Chen et al., 2024](https://arxiv.org/abs/2506.13585); [Khatri et al., 2024](https://arxiv.org/abs/2510.13786)) is a policy gradient method that uses a clipped importance ratio as a coefficient for the policy gradient. Unlike PPO which clips the objective directly, CISPO clips the ratio and uses it to weight the log probability. 
The CISPO objective is:\n\n$$\n\\mathcal{L}_{\\text{CISPO}}(\\theta) = \\mathbb{E}_{x \\sim q}\\left[\\textbf{sg}\\left( \\text{clip}\\left(\\frac{p_\\theta(x)}{q(x)}, 1-\\epsilon_{\\text{low}}, 1+\\epsilon_{\\text{high}}\\right) \\right) \\cdot \\log p_\\theta(x) \\cdot A(x)\\right]\n$$\n\nThis is implemented as:\n\n```python\n# Compute probability ratio\nprob_ratio = torch.exp(target_logprobs - sampling_logprobs)\n# Apply clipping\nclipped_ratio = torch.clamp(prob_ratio, clip_low_threshold, clip_high_threshold)\n# Compute CISPO objective (detach the clipped ratio)\ncispo_objective = clipped_ratio.detach() * target_logprobs * advantages\n# CISPO loss is negative of objective\nloss = -cispo_objective.sum()\n```\n\nAs with the PPO objective, you can pass loss function parameters like this:\n\n```python\nfwd_bwd_result = await training_client.forward_backward_async(\n    data=data,\n    loss_fn=\"cispo\",\n    loss_fn_config={\"clip_low_threshold\": 0.8, \"clip_high_threshold\": 1.2}\n)\n```\n\n### Direct Reward Optimization: `dro`\n\nDRO ([Richemond et al., 2024](https://arxiv.org/abs/2405.19107); [Kimi Team et al., 2025](https://arxiv.org/abs/2501.12599)) is a general off-policy (and even offline) reinforcement learning method that uses a quadratic penalty term to constrain the policy update. Notice that this loss uses a different (soft) formulation of the advantage estimation, which needs to be implemented on the client side.\n\nThe DRO objective is:\n\n$$\n\\mathcal{L}_{\\text{DRO}}(\\theta) = \\mathbb{E}_{x \\sim q}\\left[\\log p_\\theta(x) \\cdot A(x) - \\frac{1}{2}\\beta \\left(\\log \\frac{p_\\theta(x)}{q(x)}\\right)^2\\right]\n$$\n\nThis is implemented as:\n\n```python\n# Compute quadratic penalty term\nquadratic_term = (target_logprobs - sampling_logprobs) ** 2\n# Compute DRO objective\ndro_objective = target_logprobs * advantages - 0.5 * beta * quadratic_term\n# DRO loss is negative of objective\nloss = -dro_objective.sum()\n```\n\nAs with the other objectives, you can specify the loss hyperparameter like this:\n\n```python\nfwd_bwd_result = await training_client.forward_backward_async(\n    data=data,\n    loss_fn=\"dro\",\n    loss_fn_config={\"beta\": 0.05}\n)\n```\n\n## Flexible loss functions: `forward_backward_custom`\n\nFor use cases outside of the above, we've provided the more flexible (but slower) methods `forward_backward_custom` and `forward_backward_custom_async` to compute a more general class of loss functions.\n\n### Usage\n\nHere's a simple example of a custom loss function:\n\n```python\ndef logprob_squared_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:\n    # logprobs is a list of per-sequence tensors, so flatten it first\n    loss = (torch.cat(logprobs) ** 2).sum()\n    return loss, {\"logprob_squared_loss\": loss.item()}\n```\n\nYou can call this loss function with `forward_backward_custom` like:\n\n```python\nloss, metrics = training_client.forward_backward_custom(data, logprob_squared_loss)\n```\n\n
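You can also define loss functions which operate on multiple sequences at a time. For instance, here is a minimal sketch (not a built-in loss; `Datum` and `training_client` are the same objects as in the example above) that averages the negative log-likelihood within each sequence and then across sequences, rather than summing over all tokens:\n\n```python\ndef mean_nll_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:\n    # Mean NLL within each sequence, then mean across sequences,\n    # so long sequences don't dominate the gradient.\n    per_sequence_nll = torch.stack([-lp.mean() for lp in logprobs])\n    loss = per_sequence_nll.mean()\n    return loss, {\"mean_nll_loss\": loss.item()}\n\nloss, metrics = training_client.forward_backward_custom(data, mean_nll_loss)\n```\n\nLoss functions can also compare sequences to each other. 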
For example, a loss function that computes the variance across the sequences (although practically useless) can be implemented as:\n\n```python\ndef variance_loss(data: list[Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:\n    flat_logprobs = torch.cat(logprobs)\n    variance = torch.var(flat_logprobs)\n    return variance, {\"variance_loss\": variance.item()}\n```\n\nA more practical use case would be to compute a Bradley-Terry loss on pairwise comparison data -- a classic approach in RL from human feedback, as introduced and popularized by [Learning to Summarize](https://arxiv.org/abs/2009.01325). Similarly, we can also implement [Direct Preference Optimization](https://arxiv.org/abs/2305.18290), which also computes a loss involving pairs of sequences; see the [DPO guide](/preferences/dpo-guide) for more details.\n\nIf you're using a custom loss function that you think is generally useful, please let us know, and we'll add it to the list of built-in loss functions.\n\nWe detail the `async` version of methods in the [Async and Futures](./async) of these docs.\n\n### How `forward_backward_custom` works\n\nYou don't need to read the following section to use `forward_backward_custom`, but read on if you're curious about how it works under the hood. Tinker does NOT pickle your function or send it to the server. Instead, Tinker decomposes the gradient computation into a `forward` call followed by a `forward_backward` call on an appropriately designed weighted cross-entropy loss, which lets it compute exactly the right gradient.\n\nMathematically, this works as follows. First, consider the full nonlinear loss function:\n\n```python\nloss = compute_loss_from_logprobs(compute_target_logprobs(params))\n```\n\nWe construct a loss function that is linear in the logprobs, but has the same gradient with respect to params as the full nonlinear loss:\n\n```python\nlogprobs = compute_target_logprobs(params)\nsurrogate_loss = (logprobs * logprob_grads).sum()\n# where logprob_grads = dLoss/dLogprobs\n```\n\nHere's a diagram showing what happens under the hood, on the client and server.\n\n```mermaid\n%%{init: {'theme':'base', 'themeVariables': {'fontSize':'11px'}}}%%\nflowchart TD\n\n    subgraph Server [\"Server Side\"]\n        B[\"<b>Initial Forward Pass</b><br/><i>→ model outputs (logprobs)</i>\"]\n        I[\"<b>Forward-Backward Pass</b><br/>Linear loss: sum(outputs * grad_outputs)<br/><i>→ final gradients on weights</i>\"]\n    end\n\n    subgraph Client [\"Client Side\"]\n        A[\"<b>Prepare Data</b><br/>List of Datum objects<br/><i>→ training data</i>\"]\n        A1[\"<b>Call forward function</b><br/>\"]\n        C[\"<b>Compute Custom Loss on Model Outputs</b><br/>loss = custom_fn(logprobs)<br/><i>→ loss tensor</i>\"]\n        G[\"<b>Backward Pass</b><br/>Call loss.backward()<br/><i>→ grad_outputs</i> (dLoss/dLogprobs)\"]\n    end\n\n\n    A --> A1\n    A1 --> B\n    B --> C\n    C --> G\n    G --> I\n\n    classDef clientBox fill:#e1f5fe,stroke:#0277bd,color:#000\n    classDef serverBox fill:#f3e5f5,stroke:#7b1fa2,color:#000\n\n    class A,C,G clientBox\n    class B,I serverBox\n```\n\nSince `forward_backward_custom` does an extra forward pass, it requires 1.5x as many FLOPs as a single `forward_backward`. It'll also take up to 3x as long (wall time), due to implementation details of how `forward_backward` operations are scheduled.\n"
  },
  {
    "path": "docs/model-lineup.mdx",
    "content": "import { FilterableModelTable } from '../components/FilterableModelTable'\n\n# Available Models in Tinker\n\nThe table below shows the models that are currently available in Tinker. We plan to update this list as new models are released.\n\n## What model should I use?\n\n- In general, use MoE models, which are more cost effective than the dense models.\n- Use Base models only if you're doing research or are running the full post-training pipeline yourself\n- If you want to create a model that is good at a specific task or domain, use an existing post-trained model, and fine-tune it on your own data or environment.\n  - If you care about latency, use one of the Instruction models, which will start outputting tokens without a chain-of-thought.\n  - If you care about intelligence and robustness, use one of the Hybrid or Reasoning models, which can use long chain-of-thought.\n\n## Full Listing\n\n{/* To add or remove a model, edit the list below. The docs site renders it as an interactive filterable table. */}\n\nexport const models = [\n  { name: \"Qwen/Qwen3.5-397B-A17B\", type: \"Hybrid + Vision\", arch: \"MoE\", size: \"Large\" },\n  { name: \"Qwen/Qwen3.5-35B-A3B\", type: \"Hybrid + Vision\", arch: \"MoE\", size: \"Medium\" },\n  { name: \"Qwen/Qwen3.5-27B\", type: \"Hybrid + Vision\", arch: \"Dense\", size: \"Medium\" },\n  { name: \"Qwen/Qwen3.5-4B\", type: \"Hybrid + Vision\", arch: \"Dense\", size: \"Compact\" },\n  { name: \"Qwen/Qwen3-VL-235B-A22B-Instruct\", type: \"Vision\", arch: \"MoE\", size: \"Large\" },\n  { name: \"Qwen/Qwen3-VL-30B-A3B-Instruct\", type: \"Vision\", arch: \"MoE\", size: \"Medium\" },\n  { name: \"Qwen/Qwen3-235B-A22B-Instruct-2507\", type: \"Instruction\", arch: \"MoE\", size: \"Large\" },\n  { name: \"Qwen/Qwen3-30B-A3B-Instruct-2507\", type: \"Instruction\", arch: \"MoE\", size: \"Medium\" },\n  { name: \"Qwen/Qwen3-30B-A3B\", type: \"Hybrid\", arch: \"MoE\", size: \"Medium\" },\n  { name: \"Qwen/Qwen3-30B-A3B-Base\", type: \"Base\", arch: \"MoE\", size: \"Medium\" },\n  { name: \"Qwen/Qwen3-32B\", type: \"Hybrid\", arch: \"Dense\", size: \"Medium\" },\n  { name: \"Qwen/Qwen3-8B\", type: \"Hybrid\", arch: \"Dense\", size: \"Small\" },\n  { name: \"Qwen/Qwen3-8B-Base\", type: \"Base\", arch: \"Dense\", size: \"Small\" },\n  { name: \"Qwen/Qwen3-4B-Instruct-2507\", type: \"Instruction\", arch: \"Dense\", size: \"Compact\" },\n  { name: \"openai/gpt-oss-120b\", type: \"Reasoning\", arch: \"MoE\", size: \"Medium\" },\n  { name: \"openai/gpt-oss-20b\", type: \"Reasoning\", arch: \"MoE\", size: \"Small\" },\n  { name: \"deepseek-ai/DeepSeek-V3.1\", type: \"Hybrid\", arch: \"MoE\", size: \"Large\" },\n  { name: \"deepseek-ai/DeepSeek-V3.1-Base\", type: \"Base\", arch: \"MoE\", size: \"Large\" },\n  { name: \"meta-llama/Llama-3.1-70B\", type: \"Base\", arch: \"Dense\", size: \"Large\" },\n  { name: \"meta-llama/Llama-3.3-70B-Instruct\", type: \"Instruction\", arch: \"Dense\", size: \"Large\" },\n  { name: \"meta-llama/Llama-3.1-8B\", type: \"Base\", arch: \"Dense\", size: \"Small\" },\n  { name: \"meta-llama/Llama-3.1-8B-Instruct\", type: \"Instruction\", arch: \"Dense\", size: \"Small\" },\n  { name: \"meta-llama/Llama-3.2-3B\", type: \"Base\", arch: \"Dense\", size: \"Compact\" },\n  { name: \"meta-llama/Llama-3.2-1B\", type: \"Base\", arch: \"Dense\", size: \"Compact\" },\n  { name: \"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16\", type: \"Hybrid\", arch: \"MoE\", size: \"Large\" },\n  { name: \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", 
type: \"Hybrid\", arch: \"MoE\", size: \"Medium\" },\n  { name: \"moonshotai/Kimi-K2-Thinking\", type: \"Reasoning\", arch: \"MoE\", size: \"Large\" },\n  { name: \"moonshotai/Kimi-K2.5\", type: \"Reasoning + Vision\", arch: \"MoE\", size: \"Large\" },\n]\n\n<FilterableModelTable models={models} />\n\n## Legend\n\n### Training Types\n\n- **Base**: Foundation models trained on raw text data, suitable for post-training research and custom fine-tuning.\n- **Instruction**: Models fine-tuned for following instructions and chat, optimized for fast inference.\n- **Reasoning**: Models that always use chain-of-thought reasoning before their \"visible\" output that responds to the prompt.\n- **Hybrid**: Models that can operate in both thinking and non-thinking modes, where the non-thinking mode requires using a special renderer or argument that disables chain-of-thought.\n- **Vision**: Vision-language models (VLMs) that can process images alongside text. See [Vision Inputs](/rendering#vision-inputs) for usage.\n\n### Architecture\n\n- **Dense**: Standard transformer architecture with all parameters active\n- **MoE**: Mixture of Experts architecture with sparse activation\n\n### Model Sizes\n\n- **Compact**: 1B-4B parameters\n- **Small**: 8B parameters\n- **Medium**: 30B-32B parameters\n- **Large**: 70B+ parameters\n\nNote that the MoE models are much more cost effective than the dense models as their cost is proportional to the number of active parameters and not the total number of parameters.\n"
  },
  {
    "path": "docs/overview-building.mdx",
    "content": "# Overview: Tinker Cookbook\n\nThe next sections provide a variety of guides for how to use the Tinker API for research and applications.\n\nWe expect people to use Tinker in a few different ways:\n\n1. You want to define datasets and environments and plug them into existing training code from the Tinker Cookbook.\n2. You want to write your own training loops from scratch, starting with the basics.\n3. You want to understand the classes and other concepts in Tinker Cookbook so you can extend them to add new functionality.\n\nDifferent parts of the docs will be tailored to these different approaches.\n\nWe'll start with a couple of general pages that'll be relevant to almost all of the use cases:\n\n- [Rendering to Tokens](./rendering.mdx) -- how we convert from a conversation data structure to a list of tokens (a.k.a. chat templates).\n- [LoRA Primer](./lora-primer.mdx) -- basic background of LoRA, and how to choose hyperparameters. For most fine-tuning applications, LoRA will give results that are roughly the same as full fine-tuning, however, you need to use different learning rates.\n"
  },
  {
    "path": "docs/preferences/dpo-guide.mdx",
    "content": "import { Callout } from 'nextra/components'\nimport { CookbookLink } from '../../components/CookbookLink'\n\n# Direct Preference Optimization (DPO)\n\nDirect Preference Optimization (DPO) is a method for training language models to align with human preferences without requiring a separate reward model. Instead of using reinforcement learning with human feedback (RLHF), DPO directly optimizes the model to prefer chosen responses over rejected ones using a simple classification loss.\n\n## DPO Algorithm Details\n\nThe core DPO loss is computed as:\n\n$$\n\\mathcal{L}_{\\theta} = -\\mathbb{E}_{x, y_\\text{chosen}, y_\\text{rejected} \\sim \\mathcal{D}}\\left[\\log\\sigma\\left(\\beta\\log \\frac{\\pi_{\\theta}(y_\\text{chosen}|x)}{\\pi_{\\text{ref}}(y_\\text{chosen}|x)} - \\beta\\log \\frac{\\pi_{\\theta}(y_\\text{rejected}|x)}{\\pi_{\\text{ref}}(y_\\text{rejected}|x)}\\right)\\right]\n$$\n\nWhere:\n- $\\pi_{\\theta}$ is the current policy\n- $\\pi_{\\text{ref}}$ is the reference model (typically the initial model before DPO training)\n- $\\beta$ is the DPO beta parameter\n- Where $\\mathcal{D}$ is a dataset of prompts $x$, a chosen response $y_{\\text{chosen}}$ and a rejected response $y_{\\text{rejected}}$\n\nThis optimizes the classical constrained RLHF objective, where the reference model constrains deviation from the initial distribution.\n\n<Callout type=\"info\">\n**DPO vs RLHF**: DPO eliminates the need for a separate reward model by directly optimizing the policy to prefer chosen responses. This makes training simpler and computationally cheaper than classical RLHF.\n</Callout>\n\n\n## Running DPO Training\n\nThe implementation is in <CookbookLink path=\"tinker_cookbook/preference/train_dpo.py\">train_dpo.py</CookbookLink> with a CLI interface in <CookbookLink path=\"tinker_cookbook/recipes/preference/dpo/train.py\">train.py</CookbookLink>. 
You can run it from the command line:\n\n```bash\npython -m tinker_cookbook.recipes.preference.dpo.train \\\n    log_path=/tmp/dpo-hhh-experiment \\\n    model_name=meta-llama/Llama-3.2-1B \\\n    dataset=hhh \\\n    renderer_name=role_colon \\\n    learning_rate=1e-5 \\\n    dpo_beta=0.1\n```\n\n### Key Parameters\n\n- `log_relpath`: Directory where results and checkpoints are saved\n- `model_name`: Base model used as initialization and for the reference policy\n- `dataset`: Dataset name (`hhh`, `helpsteer3`, `ultrafeedback`)\n- `renderer_name`: How conversations are formatted (see [Rendering](../rendering.mdx))\n- `learning_rate`: Learning rate for optimization\n- `dpo_beta`: DPO beta parameter (controls the strength of preference learning)\n\n### Available Datasets\n\nThere are several pre-defined datasets:\n\n- **`hhh`**: Anthropic's Helpful-Harmless-Honest dataset\n- **`helpsteer3`**: NVIDIA's HelpSteer3 preference dataset\n- **`ultrafeedback`**: UltraFeedback binarized preferences dataset\n\nThese are implemented as `DPODatasetBuilder` classes and you can implement a custom dataset builder following the `tinker_cookbook.preference.preference_datasets` interface.\n\n## Training Process\n\nDuring training, you'll see output like this showing the DPO metrics:\n\n```\n                   Step 50\n┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓\n┃ Metric                         ┃ Value     ┃\n┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩\n│ accuracy                       │ 0.568627  │\n│ batch_time                     │ 27.953704 │\n│ chosen_reward                  │ 0.053621  │\n│ dpo_loss                       │ 0.683825  │\n│ learning_rate                  │ 0.000009  │\n│ margin                         │ 0.002147  │\n│ num_pairs                      │ 255       │\n│ num_tokens                     │ 112638    │\n│ progress                       │ 0.081210  │\n│ rejected_reward                │ 0.032152  │\n│ test/nll                       │ 1.871778  │\n└────────────────────────────────┴───────────┘\n```\n\nThe key metrics are:\n- **`dpo_loss`**: The DPO classification loss\n- **`accuracy`**: Accuracy of the implicit reward model evaluated on the preference dataset\n- **`margin`**: Average difference between chosen and rejected rewards\n- **`chosen_reward`/`rejected_reward`**: Average rewards for chosen/rejected responses\n\n## Evaluating DPO Models\n\nAfter training, you can evaluate your DPO model using the inspect evaluation framework:\n\n```bash\nMODEL_PATH=tinker://YOUR_MODEL_PATH_HERE\npython -m tinker_cookbook.eval.run_inspect_evals \\\n    model_path=$MODEL_PATH \\\n    model_name=meta-llama/Llama-3.2-1B \\\n    tasks=inspect_evals/ifeval \\\n    renderer_name=role_colon\n```\n\nThis will evaluate the model on various benchmarks to measure the impact of preference optimization.\n\n## Tips for DPO Training\n\n1. **Beta Parameter**: Start with `dpo_beta=0.1` and adjust based on your dataset.\n\n2. **Learning Rate**: Use a lower learning rate than supervised fine-tuning (typically 1e-5 to 1e-6).\n\n3. **Base Model**: The base model should already be in-distribution with the preference data. Either start with a light SFT phase or collect on-policy preferences. While training would still work, sharp distribution mis-match will create strange model behaviors.\n"
  },
  {
    "path": "docs/preferences/rlhf-example.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# Reinforcement Learning from Human Feedback\n\nWe've provided a script that shows how to run a standard pipeline for reinforcement learning from human feedback (RLHF) in <CookbookLink path=\"tinker_cookbook/recipes/preference/rlhf/rlhf_pipeline.py\">rlhf_pipeline.py</CookbookLink>.\n\n```bash\npython -m recipes.preference.rlhf.rlhf_pipeline\n```\n\n## Training the initial policy via supervised learning\n\nFirst, we train the policy on the [no_robots dataset](https://huggingface.co/datasets/HuggingFaceH4/no_robots) from Huggingface, which is a basic instruction following dataset with human-written answers, which was designed to match the methodology from [InstructGPT](https://arxiv.org/abs/2203.02155).\n\n\n## Training the preference model via supervised learning\n\nWe train the preference model on the [HHH dataset](https://huggingface.co/datasets/Anthropic/hh-rlhf) from Anthropic, which is a dataset of pairwise comparisons of completions. We train a model that sees a pair of completions, A and B, and outputs which one is preferred.\n\n## Training the policy via reinforcement learning\n\nTaking the initial policy, and the preference model we just trained, we can now train the policy via reinforcement learning. This RL is a form of self-play, where we use the preference model to grade match-ups between the policy and itself. In particular, for each prompt, we sample multiple completions, and use the preference model to grade all pairs of completions. We then give the policy a reward based on the win fraction.\n"
  },
  {
    "path": "docs/preferences.mdx",
    "content": "import { CookbookLink } from '../components/CookbookLink'\n\n# Preferences\n\n# Learning from Preferences\n\nIn this section, we focus on learning from **pairwise feedback**, where we have preference data indicating which of two completions is better for a given prompt. This kind of feedback is a natural fit for tasks where there's not a simple correctness criterion that can be computed programmatically. These preferences might be collected from human evaluators or generated by a model.\n\n## Two Approaches to Preference Learning\n\nWhen you have pairwise preference data, there are two main approaches:\n\n1. **Direct Preference Optimization (DPO)**: Directly update the policy to prefer chosen responses over rejected ones, without needing a separate reward model. This is simpler and computationally cheaper. See the [DPO Guide](/preferences/dpo-guide) for details.\n\n2. **Reinforcement Learning from Human Feedback (RLHF)**: Train a reward model on preference data, then use reinforcement learning to optimize the policy against this reward model. This two-stage approach provides more flexibility. See the [RLHF example](/preferences/rlhf-example) for details.\n"
  },
  {
    "path": "docs/publish-weights.mdx",
    "content": "# Publishing weights\n\nIf you've trained a model that you'd like to share with the community, you can\npublish any number of checkpoints you've previously saved.\n\nOnce published, your checkpoint can be loaded by any tinker user and used to\nfurther train a new model or be sampled against.\n\n### Publishing\n\n```bash\ntinker checkpoint publish $TINKER_CHECKPOINT_PATH\n```\n\nwhere `$TINKER_CHECKPOINT_PATH` is a checkpoint path in the form of `tinker://14bdf3a1-0b95-55c7-8659-5edb1bc870af:train:17/weights/checkpoint_id_to_publish`.\n\nYou may confirm your checkpoint is published by dumping the checkpoint info and checking the `Public` property:\n\n```bash\ntinker checkpoint info tinker://14bdf3a1-0b95-55c7-8659-5edb1bc870af/weights/checkpoint_id_to_publish\n                              Checkpoint: weights/checkpoint_id_to_publish\n┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n┃ Property        ┃ Value                                                                          ┃\n┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n│ Checkpoint ID   │ weights/checkpoint_id_to_publish                                               │\n│ Type            │ training                                                                       │\n│ Tinker Path     │ tinker://14bdf3a1-0b95-55c7-8659-5edb1bc870af/weights/checkpoint_id_to_publish │\n│ Size            │ 342.4 MB                                                                       │\n│ Public          │ No                                                                             │\n│ Created         │ 23 minutes ago                                                                 │\n│ Training Run ID │ 14bdf3a1-0b95-55c7-8659-5edb1bc870af                                           │\n└─────────────────┴────────────────────────────────────────────────────────────────────────────────┘\n```\n\n### Unpublishing\n\n```bash\ntinker checkpoint unpublish $TINKER_CHECKPOINT_PATH\n```\n\n### Loading public weights\n\nLoading public weights is exactly the same as loading a non-public one:\n\n```python\nckpt_path = ...\ntraining_client = service_client.create_training_client_from_state(ckpt_path)\n```\n"
  },
  {
    "path": "docs/rendering.mdx",
    "content": "import { CookbookLink } from \"../components/CookbookLink\";\n\n# Rendering to tokens\n\nRendering converts list-of-message datatypes into their token representations for model training and inference. While similar to [chat templates](https://huggingface.co/docs/transformers/en/chat_templating), Tinker's rendering system is designed for the full training lifecycle--not just inference--supporting supervised learning, reinforcement learning, and deployment.\n\n\n## The Renderer class\n\nThe Renderer class is the main interface used for rendering. It can be found in <CookbookLink path=\"tinker_cookbook/renderers/\">`tinker_cookbook/renderers/`</CookbookLink>.\n\n**Example conversation:**\n\n```python\nmessages =[\n    {'role': 'system', 'content': 'Answer concisely; at most one sentence per response'},\n    {'role': 'user', 'content': 'What is the longest-lived rodent species?'},\n    {'role': 'assistant', 'content': 'The naked mole rat, which can live over 30 years.'},\n    {'role': 'user', 'content': 'How do they live so long?'},\n    {'role': 'assistant', 'content': 'They evolved multiple protective mechanisms including special hyaluronic acid that prevents cancer, extremely stable proteins, and efficient DNA repair systems that work together to prevent aging.'}\n]\n```\n\nWe'll use this conversation throughout the examples below.\n\n## Inference: Generating messages\n\nOur model maps tokens to tokens, but with the renderer, it can map messages to messages. To sample messages from the model, we need to use three methods from the renderer:\n\n- `build_generation_prompt`\n- `get_stop_sequences`\n- `parse_response`\n\n`build_generation_prompt` converts a conversation into a prompt that we can use to sample from the assistant. This is used during reinforcement learning and at deployment time.\n\n**Example: Generate an alternative assistant response**\n\nLet's remove the last assistant message and call `build_generation_prompt` to get a prompt that we can use to sample an alternative response from the assistant:\n\n```python\nfrom tinker_cookbook import renderers, tokenizer_utils\ntokenizer = tokenizer_utils.get_tokenizer('Qwen/Qwen3-30B-A3B')\nrenderer = renderers.get_renderer('qwen3', tokenizer)\nprompt = renderer.build_generation_prompt(messages[:-1])\nprint(prompt)\nprint('-'*10)\nprint(tokenizer.decode(prompt.to_ints()))\n```\n\n**Output:**\n\n```\nModelInput(chunks=[EncodedTextChunk(tokens=[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 8948, 198, 16141, 3529, 285, 974, 26, 518, 1429, 825, 11652, 817, 2033, 151645, 198, 151644, 872, 198, 3838, 374, 279, 22032, 61854, 20589, 306, 9419, 30, 151645, 198, 151644, 77091, 198, 785, 19020, 34651, 11244, 11, 892, 646, 3887, 916, 220, 18, 15, 1635, 13, 151645, 198, 151644, 872, 198, 10234, 30, 151645, 198, 151644, 77091, 198], type='encoded_text')])\n----------\n<|im_start|>system\nAnswer concisely; at most one sentence per response<|im_end|>\n<|im_start|>user\nWhat is the longest-lived rodent species?<|im_end|>\n<|im_start|>assistant\nThe naked mole rat, which can live over 30 years.<|im_end|>\n<|im_start|>user\nHow do they live so long?<|im_end|>\n<|im_start|>assistant\n\n```\n\nYou can see that the prompt is a `ModelInput` object, which is a list of `EncodedTextChunk` objects (but contains different objects in multi-modal data).\n\n**Sampling and parsing the response:**\n\nGiven that we're providing messages as input, we probably want a message output, rather than a token output. 
For that, we can use `parse_response`.\n\n```python\nimport tinker\nfrom tinker.types import SamplingParams\nservice_client = tinker.ServiceClient()\nsampling_client = service_client.create_sampling_client(base_model='Qwen/Qwen3-30B-A3B')\nstop_sequences = renderer.get_stop_sequences()\nprint(f\"Stop sequences: {stop_sequences}\")\nsampling_params = SamplingParams(max_tokens=100, temperature=0.5, stop=stop_sequences)\noutput = sampling_client.sample(prompt, sampling_params=sampling_params, num_samples=1).result()\nprint(f\"Sampled tokens: {output.sequences[0].tokens}\")\nsampled_message, parse_success = renderer.parse_response(output.sequences[0].tokens)\nprint(f\"Sampled message: {sampled_message}\")\nprint(f\"Parse success: {parse_success}\")\n```\n\n**Output:**\n\n```\nStop sequences: [151645]\nSampled tokens: [45, 7741, 34651, 31410, 614, 4911, 76665, 11, 2670, 264, 7548, 11050, 22077, 1849, 323, 264, 1602, 3347, 40761, 4379, 11, 892, 16792, 311, 862, 57119, 13, 151645]\nSampled message: {'role': 'assistant', 'content': 'Naked mole rats have unique adaptations, including a highly efficient immune system and a very low metabolic rate, which contribute to their longevity.'}\nParse success: True\n```\n\nYou can see that there is one stop sequence, `151645`, which you can verify is the `<|im_end|>` token. The output is parsed successfully into a message.\n\n## Training: Supervised learning\n\nFor supervised learning (and some other algorithms like [DPO](/preferences/dpo-guide)), we need to distinguish between **prompt tokens** (context) and **completion tokens** (what the model should learn to generate). We want to provide a target assistant message, and the renderer needs to tell us which tokens are part of the prompt and completion.\n\nWe can use `build_supervised_example` to get a `ModelInput` and per-token loss weights:\n\n```python\nmodel_input, weights = renderer.build_supervised_example(messages)\n\nfrom tinker_cookbook.utils.format_colorized import format_colorized\nprint(format_colorized(model_input.to_ints(), weights, tokenizer))\n```\n\nWe get the following output:\n\n<div className=\"example\">\n  <span className=\"prompt\">\n    &lt;|im_start|&gt;system↵\n    <br />\n    Answer concisely; at most one sentence per response&lt;|im_end|&gt;↵\n    <br />\n    &lt;|im_start|&gt;user↵\n    <br />\n    What is the longest-lived rodent species?&lt;|im_end|&gt;↵\n    <br />\n    &lt;|im_start|&gt;assistant↵\n    <br />\n    The naked mole rat, which can live over 30 years.&lt;|im_end|&gt;↵\n    <br />\n    &lt;|im_start|&gt;user↵\n    <br />\n    How do they live so long?&lt;|im_end|&gt;↵\n    <br />\n    &lt;|im_start|&gt;assistant↵\n    <br />\n  </span>\n  <span className=\"completion\">\n    They evolved multiple protective mechanisms including special hyaluronic\n    acid that prevents cancer, extremely stable proteins, and efficient DNA\n    repair systems that work together to prevent aging.&lt;|im_end|&gt;\n    <br />\n  </span>\n</div>\nThe green text is part of the prompt (i.e. with `weight=0`, so no loss is computed\non these) and red is part of the completion (i.e. with `weight=1`, so the model is\ntrained to predict these). Note that the ↵ have been inserted for clarity to show\nnewlines; these are not actually part of the token sequence.\n\nThe key insight here is that only the final assistant message is treated as the completion. 
All previous context, including the first assistant response, is part of the prompt, so the model learns to continue conversations rather than just answer single questions.\n\n## Vision Inputs\n\nTinker supports vision-language models (VLMs) like `Qwen/Qwen3-VL-30B-A3B-Instruct` and `Qwen/Qwen3-VL-235B-A22B-Instruct`. For low-level `ImageChunk` usage, see [Vision inputs](/training-sampling#vision-inputs) in the Training and Sampling guide. This section covers the higher-level message abstractions.\n\n### Multimodal messages\n\nFor VLMs, message content can be either a string or a list of content parts:\n\n```python\nfrom tinker_cookbook.renderers import Message, TextPart, ImagePart\n\n# Text-only message (standard)\ntext_message = Message(role='user', content='What is this?')\n\n# Multimodal message with image\nmultimodal_message = Message(\n    role='user',\n    content=[\n        ImagePart(type='image', image='https://thinkingmachines.ai/blog/on-policy-distillation/images/chess.png'),\n        TextPart(type='text', text='What is in this image?'),\n    ]\n)\n```\n\nFor lower-level control using `ImageChunk` directly, see [Vision inputs](/training-sampling#vision-inputs) in the Training and Sampling guide.\n\n### Using Qwen3VLRenderer\n\nThe `Qwen3VLRenderer` and `Qwen3VLInstructRenderer` handle Qwen's vision special tokens (`<|vision_start|>`, `<|vision_end|>`) automatically:\n\n```python\nfrom tinker_cookbook import renderers, tokenizer_utils\nfrom tinker_cookbook.image_processing_utils import get_image_processor\n\nmodel_name = \"Qwen/Qwen3-VL-235B-A22B-Instruct\"\ntokenizer = tokenizer_utils.get_tokenizer(model_name)\nimage_processor = get_image_processor(model_name)\n\nrenderer = renderers.Qwen3VLInstructRenderer(tokenizer, image_processor)\n\nmessages = [\n    {\n        'role': 'user',\n        'content': [\n            {'type': 'image', 'image': 'https://thinkingmachines.ai/blog/on-policy-distillation/images/chess.png'},\n            {'type': 'text', 'text': 'What is in this image?'},\n        ]\n    }\n]\n\nprompt = renderer.build_generation_prompt(messages)\n```\n\nFor a complete example of training a VLM image classifier, see the <CookbookLink path=\"tinker_cookbook/recipes/vlm_classifier\">VLM Classifier recipe</CookbookLink> in the cookbook.\n\n## HuggingFace Compatibility\n\n**Important:** Tinker's default renderers are designed to produce **identical tokens** to HuggingFace's `apply_chat_template`. This is critical because:\n\n1. **The [OpenAI-compatible endpoint](/compatible-apis/openai)** (`/chat/completions`) uses HuggingFace chat templates to convert messages to tokens\n2. **If you train with a non-HF-compatible renderer**, your model may not work correctly with the OpenAI endpoint\n\nThe default renderers (`qwen3`, `llama3`, `deepseekv3`, etc.) match HF behavior. 
The exception is when using `strip_thinking_from_history=False` on thinking-enabled renderers (`Qwen3Renderer`, `DeepSeekV3ThinkingRenderer`)—this is a special mode for multi-turn RL efficiency that does NOT match HF (see [Sequence Extension](/rl/sequence-extension)).\n\n| Renderer                 | HF Equivalent                                              |\n| ------------------------ | ---------------------------------------------------------- |\n| `qwen3`                  | `apply_chat_template(..., enable_thinking=True)` (default) |\n| `qwen3_disable_thinking` | `apply_chat_template(..., enable_thinking=False)`          |\n| `llama3`                 | `apply_chat_template(...)` *                               |\n| `deepseekv3`             | `apply_chat_template(...)`                                 |\n\n\\* The Llama3 renderer omits the \"Cutting Knowledge Date...\" preamble that HF prepends to system messages. Add this manually if you need exact HF compatibility.\n\n**Recommendation:** If you plan to use the OpenAI endpoint for inference, always use the default renderers with default options.\n\n\n## Multi-turn RL and the Extension Property\n\nWhen using renderers in multi-turn RL, an important consideration is whether consecutive timesteps satisfy the **extension property**—where each observation is a prefix extension of the previous observation plus action. This affects compute efficiency (O(T) vs O(T^2)) and KV-cache reuse.\n\nSome renderers, like `Qwen3Renderer`, have options that affect this property. For example, `strip_thinking_from_history` controls whether `<think>` blocks are preserved in conversation history.\n\nSee the [Sequence Extension](/rl/sequence-extension) documentation for details on how this works and the tradeoffs involved.\n\n## Appendix: Why not Jinja templates?\n\nHuggingFace chat templates (Jinja2) are designed for a single use case: converting messages to tokens for inference. Tinker's renderer system handles the full training lifecycle, which requires capabilities that chat templates don't provide:\n\n1. **Supervised learning requires per-token loss weights.** Chat templates produce a flat token sequence; they can't distinguish which tokens are prompt (weight=0) vs completion (weight=1). The renderer's `build_supervised_example` method returns both tokens and weights.\n\n2. **Parsing model output back into messages.** After sampling, you need to convert tokens back into structured messages. This includes knowing the stop sequences (`get_stop_sequences`), extracting thinking blocks (`<think>...</think>`), and parsing tool calls. The `parse_response` method handles all of this, including graceful handling of malformed output.\n\n3. **Tool calling details vary by model.** Each model family has its own tool calling format (Qwen uses `<tool_call>` XML tags, DeepSeek uses special tokens, Llama3 doesn't support tool calling reliably). The renderers encode these formats correctly and parse tool calls from model output, including handling parse failures.\n\n4. **Precise tokenization control.** Tokenization can produce different results depending on how strings are split. For example, tokenizing `\"Hello\" + \" world\"` separately may differ from `\"Hello world\"`. The renderers directly construct a token sequence, rather than strings, so they give you precise control over the token sequence, which is important for ensuring train-test consistency and KV-cache friendliness.\n"
  },
  {
    "path": "docs/rl/rl-basic.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# Your First RL Run\n\nWe've provided a minimal script that runs RL on the [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k): <CookbookLink path=\"tinker_cookbook/recipes/rl_basic.py\">rl_basic.py</CookbookLink>. You can run the minimal RL script from the command line as follows:\n\n```bash\npython -m tinker_cookbook.recipes.rl_basic\n```\n\nThis script will fine-tune the Llama-3.1-8B base (pretrained) model on this dataset with the following reward function:\n\n$$\n1[\\text{answer is correct}] + 0.1 \\times (1[\\text{answer is formatted correctly}] - 1)\n$$\n\nThe training should take about 1 minute per iteration and climb to about 63% accuracy after 15 iterations (`env/all/correct`). You can look at the printouts for some other metrics of interest:\n\n- `ac_tokens_per_turn`: the number of each tokens in each generated completion\n- `env/all/format`: the fraction of completions that are formatted correctly\n- `env/all/reward/total`: mean total reward (combining format and correctness as defined above)\n- `entropy`: per-token entropy (mean negative log-probability of sampled tokens)\n- `kl_sample_train_{v1,v2}`: two different approximations/estimators of KL divergence between the sampler's and learner's probability distribution (contributed to by numerical differences and rounding noise)\n- `progress/done_frac`: what fraction of the total number of iterations we've completed so far\n- `time/...`: time for different parts of the training loop\n\nYou can also look at the `log_path` directory for detailed metrics, rollout transcripts, and per-trajectory summaries. See [RL Training Outputs](/rl/rl-logging) for the full file reference and parsing examples.\n"
  },
  {
    "path": "docs/rl/rl-envs.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# RL Environments\n\nHere, we'll explain how to create your own RL environments and train on them. First, lets look at the basic classes, which can be found in <CookbookLink path=\"tinker_cookbook/rl/types.py\">`tinker_cookbook.rl.types`</CookbookLink>. As you can see, there's an `Env` interface, corresponding to an RL environment. To write an environment, you need to implement two methods: `initial_observation` and `step`.\n\n```python\nclass Env:\n    \"\"\"\n    Stateful environment that a single agent interacts with.\n    Discard after running for one episode.\n    \"\"\"\n\n    async def initial_observation(self) -> tuple[Observation, StopCondition]:\n        raise NotImplementedError\n\n    async def step(self, action: Action) -> StepResult:\n        raise NotImplementedError\n```\n\nNote that this `Env` operates on *tokens*, rather than strings or messages. Why define it this way, when it's usually more natural to define the logic in terms of strings or messages? We've defined `Env` this way because this interface is what's needed by the *training* code, which needs to know the exact tokens that were sampled, and their logprobs.\n\nWe need to write two more small classes to use this environment in the RL training code. First, since the environment is discarded after a single episode, we need to be able to instantiate new environments in the training loop. We actually build a *group* of environments at a time, which enables multi-agent training or objectives that compare multiple samples (for example, a reward model that acts on a pair of samples).\n\n```python\nclass EnvGroupBuilder:\n    \"\"\"\n    Builds a group of environments.\n    \"\"\"\n\n    async def make_envs(self) -> Sequence[Env]:\n        raise NotImplementedError\n```\n\nThis object creates a group of environments. Often it does the trivial thing of returning a list of copies of the same environment.\n\nFinally, we need a dataset of these EnvGroupBuilders.\n\n```python\nclass RLDataset:\n    \"\"\"\n    Dataset of EnvGroupBuilders.\n    \"\"\"\n\n    def get_batch(self, index: int) -> list[EnvGroupBuilder]:\n        raise NotImplementedError\n```\n\n\nThat's a lot of classes! But their combination gives us a lot of flexibility. In previous implementations (like OpenAI Gym), the dataset is implicitly part of the environment; this structure is more modular and gives us more control over the data loading.\n\n## Building a simple example\n\nYou can find an example of writing a new RL environment in the <CookbookLink path=\"tinker_cookbook/recipes/multiplayer_rl/twenty_questions\">Twenty Questions</CookbookLink> directory.\nHere, we define a multi-step environment, where we're training a question-asking agent, which asks questions to another agent to guess a hidden word.\nIn this case, the answerer model is fixed and is Llama-3.1-8B-Instruct.\nThe player model (which we fine-tune) is also based on that same model.\n\nYou can run the training script as follows:\n\n```bash\npython -m tinker_cookbook.recipes.multiplayer_rl.twenty_questions.train\n```\n"
  },
  {
    "path": "docs/rl/rl-hyperparams.mdx",
    "content": "# RL Hyperparameters\n\nThis guide covers the key hyperparameters for reinforcement learning training, from core settings to advanced configurations.\n\n## Core Hyperparameters\n\n### Learning Rate\n\nSimilar to the [supervised learning setting](../supervised-learning/sl-hyperparams), the learning rate is the most critical hyperparameter choice. We recommend using the guidance presented there as a starting point for RL experiments as well.\n\n\n### Batch and Group Sizes\n\nAs described in our [RL environments](../rl/rl-envs.mdx) documentation, we use two key parameters:\n\n- **`batch_size`**: The number of unique environments or problems used for training\n- **`group_size`**: The number of rollouts performed per unique environment\n\nIf you have limited environments or problems available for training, increase the `group_size` to generate more training data. While the total number of rollouts depends on both parameters, we recommend scaling learning rates proportionally to $\\text{LR} \\propto \\sqrt{\\text{batch\\_size}}$.\n\n## Multiple Updates per Sampling Iteration\n\nThe `num_substeps` parameter controls how many policy weight updates are performed on data sampled from the last policy iteration, similar to PPO and GRPO.\n\n### How it works:\n\n- **`num_substeps = 1` (default)**: Each batch of collected trajectories is used for exactly one optimizer update\n- **`num_substeps > 1`**: The batch of unique environments is split into `num_substeps` mini-batches, where each environment/problem has `group_size` rollouts (we pack all rollouts for a particular environment/problem in the same minibatch). We do a single update step on each mini-batch. Note that our implementation still takes only a single epoch through the data.\n\n### Usage Guidelines:\n\n- The batch size must be divisible by `num_substeps`\n- Our experiments show that `num_substeps = 1` already gives decent performance, but if you would like to experiment with this parameter, we recommend starting with a low value of 2-4 and using the PPO objective.\n- Higher values can lead to update steps that are too out-of-distribution for the policy. Consider limiting the number of updates or decreasing the learning rate when using multiple update steps.\n\n## Advanced Training Configurations\n\n⚠️ **Note**: These features are experimental and may be subject to instabilities. They are currently disabled by default.\n\n### Streaming Minibatch Training\n\nEnable streaming minibatch training by specifying the `StreamMinibatchConfig`. This approach overlaps trajectory sampling and model training, improving overall throughput by submitting training requests as soon as enough rollouts complete, without waiting for all sampling jobs to finish.\n\n**Configuration Parameters:**\n\n- **`groups_per_batch`**: Same as batch size\n- **`num_minibatches`**: Number of minibatches per substep—controls how many individual forward-backward requests we submit. This controls how the work is split.\n\n\n**Important**: This remains on-policy training and is strictly a pipeline efficiency improvement.\n\n### Async Off-Policy Training\n\nAsync training allows the model to train on trajectories generated with slightly older model versions, enabling higher throughput at the cost of some off-policy bias. While Tinker doesn't currently support in-flight weight changes, it supports the \"off-by-K\" async RL approach where multiple model iterations generate data simultaneously. 
Configure this by setting the `AsyncConfig` object.\n\n**Configuration Parameters:**\n\n- **`max_steps_off_policy`**: Maximum age (in training steps) of trajectories before they're discarded. Essentially, trajectories from policy iterations older than `max_steps_off_policy` steps will not be used.\n- **`groups_per_batch`**: Number of new trajectory groups to accumulate (with a `group_size` number of rollouts each) before updating the current iteration of the model. Note: This is separate from the batch size used for dataset construction.\n\n**Usage Guidelines:**\n\n- Async RL is appropriate for applications with long and heterogeneous rollouts, such as very long CoT models, multi-hop tool use, or agentic workflows\n- Start with a small value for `max_steps_off_policy` (less than 5)\n\n\n\n## Monitoring and Run Health\n\nUsing policy-gradient algorithms with off-policy data can significantly degrade performance or even crash the policy, making monitoring essential during training.\n\n### KL Divergence Monitoring\n\nThe current implementation logs the KL divergence between the data generation policy and the current learner: $\\mathbb{D}_{KL}[\\pi_{\\text{sampler}}(\\cdot|x)||\\pi_{\\theta}(\\cdot|x)]$ using two separate estimators ([Schulman 2020](http://joschu.net/blog/kl-approx.html)):\n\n- `kl_sample_train_v1`\n- `kl_sample_train_v2`\n\n\nA few important notes to keep in mind:\n- Even with full on-policy training, the divergence between sampling and learning policies will not be exactly zero ([He 2025](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/)) due to implementation details\n- In our experience training is stable with KL divergence below 0.01\n- If KL divergence crosses a recommended threshold, this indicates a numerical instability or potential issue with the training run\n"
  },
  {
    "path": "docs/rl/rl-logging.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# RL Training Outputs\n\nEach RL training run writes files to `log_path`. This page describes each file and how to extract data from it.\n\n## Files written to `log_path`\n\n| File | Format | Contents |\n|------|--------|----------|\n| `metrics.jsonl` | JSONL | One JSON object per training iteration with all scalar metrics |\n| `config.json` | JSON | Serialized training config (hyperparams, model, dataset, etc.) |\n| `checkpoints.jsonl` | JSONL | Checkpoint metadata (paths, loop state for resume) |\n| `train_iteration_NNNNNN.html` | HTML | Human-readable logtree report for training rollouts |\n| `train_iteration_NNNNNN_logtree.json` | JSON | Machine-readable export of the same logtree trace |\n| `train_iteration_NNNNNN_rollout_summaries.jsonl` | JSONL | One JSON object per trajectory with rewards, metrics, and step-level data |\n| `eval_<name>_iteration_NNNNNN.html` | HTML | Logtree report for eval rollouts |\n| `eval_<name>_iteration_NNNNNN_logtree.json` | JSON | Machine-readable export of eval logtree trace |\n| `eval_<name>_iteration_NNNNNN_rollout_summaries.jsonl` | JSONL | Per-trajectory eval data (for `RLTestSetEvaluator`) |\n| `code.diff` | text | Git diff at the time training started |\n\n`<name>` is the evaluator name (sanitized for filenames); iteration numbers are zero-padded to 6 digits.\n\n## `metrics.jsonl`\n\nEach line is a JSON object keyed by metric name. Common keys (varies by env and config):\n\n- `progress/batch`, `progress/done_frac` — iteration index and completion fraction\n- `env/all/reward/total` — mean total reward across all trajectories\n- `env/all/<metric>` — env-emitted metrics (e.g., `format_parse`, `correct`)\n- `ac_tokens_per_turn` — mean generated tokens per turn\n- `entropy` — per-token entropy\n- `kl_sample_train_v1`, `kl_sample_train_v2` — KL divergence estimators\n- `optim/lr` — learning rate\n- `time/...` — wall-clock timings for different stages\n\n```python\nimport pandas as pd\n\ndf = pd.read_json(\"path/to/metrics.jsonl\", lines=True)\ndf.plot(x=\"progress/batch\", y=\"env/all/reward/total\")\n```\n\n## `*_rollout_summaries.jsonl`\n\nOne line per trajectory. Best for aggregate analysis (reward distributions, per-step metrics).\n\n```python\nimport json\n\nwith open(\"train_iteration_000010_rollout_summaries.jsonl\") as f:\n    trajectories = [json.loads(line) for line in f]\n\n# Each trajectory has:\n# - metadata: schema_version, split, iteration, group_idx, traj_idx, tags, sampling_client_step\n# - episode totals: total_reward, final_reward, trajectory_metrics, final_ob_len\n# - steps: list of {step_idx, ob_len, ac_len, reward, episode_done, metrics, logs}\n```\n\n## `*_logtree.json`\n\nThe logtree JSON contains full rollout transcripts: prompts, model responses, grading details, and reward breakdowns. Use this when you need the actual text content of rollouts.\n\nTop level: `title`, `started_at`, `path`, `root`. `root` is a tree of nodes, each with `tag`, `attrs`, and `children` (either text strings or nested nodes).\n\nSome nodes carry a **`data`** field with structured content. 
Use `data` to extract typed data like conversation messages:\n\n```python\nimport json\n\ndef find_conversations(node):\n    \"\"\"Recursively find all nodes with conversation data.\"\"\"\n    results = []\n    if isinstance(node, dict):\n        if node.get(\"data\", {}).get(\"type\") == \"conversation\":\n            results.append(node[\"data\"])\n        for child in node.get(\"children\", []):\n            if isinstance(child, dict):\n                results.extend(find_conversations(child))\n    return results\n\nwith open(\"eval_test_iteration_000020_logtree.json\") as f:\n    trace = json.load(f)\n\nfor conv in find_conversations(trace[\"root\"]):\n    for msg in conv[\"messages\"]:\n        print(f\"{msg['role']}: {msg['content'][:100] if isinstance(msg['content'], str) else '...'}\")\n```\n\nNote: `num_groups_to_log` (default: 4) controls how many trajectory groups get detailed env-level logging. Groups beyond this limit have no rollout content in the logtree — only the `Trajectory Details` section (turn-level stats) is always present.\n\n## `config.json`\n\nSerialized `chz` config capturing all training hyperparameters. Useful for reproducing a run or comparing configs across experiments.\n\n## `checkpoints.jsonl`\n\nEach line records a saved checkpoint with its path and the loop state at save time. Used by the resume logic to pick up where training left off.\n"
  },
  {
    "path": "docs/rl/rl-loops.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# Reinforcement Learning Training Loop\n\nWe've provided a simple RL training loop in <CookbookLink path=\"tinker_cookbook/recipes/rl_loop.py\">rl_loop.py</CookbookLink>, which avoids using our environment classes and instead defines the data loading and rollouts in a more self-contained way. This is for people who like to write their own training loops or learn about how things work under the hood. Our more performant implementation in <CookbookLink path=\"tinker_cookbook/rl/train.py\">rl/train.py</CookbookLink> does basically the same thing, but with some performance optimizations, and with some additional features like periodic evals.\n\nYou can run the RL training loop using:\n```\npython -m tinker_cookbook.recipes.rl_loop\n```\n\nThe default config should write the results to `/tmp/tinker-examples/rl-loop`. The experiment should be completed after 57 steps of training. You can plot the reward curve as follows:\n```python\nimport pandas\nimport matplotlib.pyplot as plt\n\nmetrics_path = \"/tmp/tinker-examples/rl-loop/metrics.jsonl\"\ndf = pandas.read_json(metrics_path, lines=True)\nplt.plot(df[\"reward/total\"], label=\"reward/total\")\nplt.legend()\nplt.show()\n```\n\nYou should see a plot like this:\n![Reward as a function of steps](./images/rl_loop_reward.png)\n"
  },
  {
    "path": "docs/rl/sequence-extension.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# Sequence Extension Property in Multi-Turn RL\n\nWhen running reinforcement learning with multi-turn conversations, the way you render observations at each timestep has important implications for compute efficiency. This document explains the **extension property** and how it affects training and sampling.\n\n## What is the Extension Property?\n\nA sequence of observations has the **extension property** if each successive observation contains all previous observations and actions as a prefix. In other words, the context grows monotonically by appending new tokens to the end.\n\nWhen this property holds, multiple timesteps can be merged into a single training datum, the KV-cache can be reused during sampling, and compute scales as O(T) rather than O(T^2) for a trajectory of length T.\n\n## Example 1: Qwen3 with Thinking Visible (Extension Holds)\n\nWhen using `Qwen3Renderer` with `strip_thinking_from_history=False`, the full conversation history (including `<think>` blocks) is preserved at each timestep. Consider a two-turn math conversation:\n\n**Timestep 1:**\n<div className=\"example\">\n<span className=\"prompt\">User: What is 2+2?<br/><br/>Assistant: </span><span className=\"completion\">&lt;think&gt;Let me calculate...&lt;/think&gt; 4<br/><br/>User:</span>\n</div>\n\n**Timestep 2:**\n<div className=\"example\">\n<span className=\"prompt\">User: What is 2+2?<br/><br/>Assistant: &lt;think&gt;Let me calculate...&lt;/think&gt; 4<br/><br/>User: What is 3+3?<br/><br/>Assistant: </span><span className=\"completion\">&lt;think&gt;Let me calculate...&lt;/think&gt; 6<br/><br/>User:</span>\n</div>\n\nNotice that the observation (green) at timestep 2 contains the entire timestep 1 sequence as a prefix. The new observation just appends `What is 3+3?\\n\\nAssistant: ` to the end. This is the **extension property**.\n\nBecause extension holds, the RL code can merge both timesteps into a **single Datum**:\n\n<div className=\"example\">\n<span className=\"prompt\">User: What is 2+2?<br/><br/>Assistant: </span><span className=\"completion\">&lt;think&gt;Let me calculate...&lt;/think&gt; 4<br/><br/>User:</span><span className=\"prompt\"> What is 3+3?<br/><br/>Assistant: </span><span className=\"completion\">&lt;think&gt;Let me calculate...&lt;/think&gt; 6<br/><br/>User:</span>\n</div>\n\nGreen = observation tokens (loss weight = 0). Red = action tokens (loss weight > 0).\n\n## Example 2: Qwen3 with Thinking Hidden (Extension Breaks)\n\nWhen using `Qwen3Renderer` with the default `strip_thinking_from_history=True`, the `<think>...</think>` blocks are stripped from previous assistant messages. This matches how Qwen3 models were post-trained by the Qwen team.\n\n**Timestep 1:**\n<div className=\"example\">\n<span className=\"prompt\">User: What is 2+2?<br/><br/>Assistant: </span><span className=\"completion\">&lt;think&gt;Let me calculate...&lt;/think&gt; 4<br/><br/>User:</span>\n</div>\n\n**Timestep 2:**\n<div className=\"example\">\n<span className=\"prompt\">User: What is 2+2?<br/><br/>Assistant: 4<br/><br/>User: What is 3+3?<br/><br/>Assistant: </span><span className=\"completion\">&lt;think&gt;Let me calculate...&lt;/think&gt; 6<br/><br/>User:</span>\n</div>\n\nThe observation at timestep 2 is **not** an extension of timestep 1's full sequence. The `<think>Let me calculate...</think>` portion was stripped, so the prefix doesn't match. 
The RL code must create **two separate Datums**:\n\n**Datum 1:**\n<div className=\"example\">\n<span className=\"prompt\">User: What is 2+2?<br/><br/>Assistant: </span><span className=\"completion\">&lt;think&gt;Let me calculate...&lt;/think&gt; 4<br/><br/>User:</span>\n</div>\n\n**Datum 2:**\n<div className=\"example\">\n<span className=\"prompt\">User: What is 2+2?<br/><br/>Assistant: 4<br/><br/>User: What is 3+3?<br/><br/>Assistant: </span><span className=\"completion\">&lt;think&gt;Let me calculate...&lt;/think&gt; 6<br/><br/>User:</span>\n</div>\n\nThis results in more compute during training (two forward/backward passes instead of one) and prevents KV-cache reuse during sampling. For a trajectory of T timesteps, compute scales as O(T²) instead of O(T).\n\n## The Tradeoff\n\n**Keeping thinking visible** (`strip_thinking_from_history=False`) gives you O(T) compute scaling, allows packing sequences together in training batches, and enables KV-cache reuse during sampling. The downside is that context grows faster since all thinking tokens are retained, so you may hit context length limits sooner.\n\n**Stripping thinking** (`strip_thinking_from_history=True`, the default) keeps context smaller but breaks the extension property, leading to O(T²) compute scaling.\n\nNote that while stripping thinking matches Qwen3's original post-training distribution, with RL fine-tuning the model should quickly adapt to the new situation where thinking is preserved. So \"distribution match\" might not be a major concern in practice.\n\n## How the RL Code Handles This\n\nThe RL training code in <CookbookLink path=\"tinker_cookbook/rl/data_processing.py\">`data_processing.py`</CookbookLink> automatically detects whether consecutive timesteps satisfy the extension property. The key function is `trajectory_to_data`:\n\n```python\ndef trajectory_to_data(traj: Trajectory, traj_advantage: float) -> list[tinker.Datum]:\n    \"\"\"\n    Return one or more Datum objects corresponding to the trajectory.\n    If the sequence grows by appending, i.e., each successive observation contains\n    the previous observation+action as a prefix, then we can return a single Datum.\n    However, if we get a sequence that's not an extension of the previous sequence,\n    then that results in a new Datum.\n    \"\"\"\n```\n\nWhen rendering your conversations, be aware of whether your renderer has the extension property. You can check programmatically via `renderer.has_extension_property`. For `Qwen3Renderer`:\n- `strip_thinking_from_history=False` → `has_extension_property=True`\n- `strip_thinking_from_history=True` (default) → `has_extension_property=False`\n\n**Note on sampling:** The training code automatically merges timesteps when possible. 
Sampling infrastructure doesn't yet adjust billing based on KV-cache hits, but this is planned for a future release.\n\n## Advanced: Periodic Compaction\n\nA hybrid approach is to use **periodic compaction**: keep thinking visible most of the time (preserving extension), but periodically clear old thinking blocks from the context.\n\n**How it works:**\n- For turns 1-10, keep all thinking visible (extension holds, single datum)\n- At turn 11, strip thinking from turns 1-10 (extension breaks once, new datum starts)\n- For turns 11-20, keep thinking visible again (extension holds)\n- Repeat every N turns\n\nHere's what the datums look like with compaction every 3 turns:\n\n**Datum 1 (turns 1-3):**\n<div className=\"example\">\n<span className=\"prompt\">User: Q1<br/>Assistant: </span><span className=\"completion\">&lt;think&gt;...&lt;/think&gt; A1<br/>User:</span><span className=\"prompt\"> Q2<br/>Assistant: </span><span className=\"completion\">&lt;think&gt;...&lt;/think&gt; A2<br/>User:</span><span className=\"prompt\"> Q3<br/>Assistant: </span><span className=\"completion\">&lt;think&gt;...&lt;/think&gt; A3<br/>User:</span>\n</div>\n\n**Datum 2 (turns 4-6, thinking from turns 1-3 stripped):**\n<div className=\"example\">\n<span className=\"prompt\">User: Q1<br/>Assistant: A1<br/>User: Q2<br/>Assistant: A2<br/>User: Q3<br/>Assistant: A3<br/>User: Q4<br/>Assistant: </span><span className=\"completion\">&lt;think&gt;...&lt;/think&gt; A4<br/>User:</span><span className=\"prompt\"> Q5<br/>Assistant: </span><span className=\"completion\">&lt;think&gt;...&lt;/think&gt; A5<br/>User:</span><span className=\"prompt\"> Q6<br/>Assistant: </span><span className=\"completion\">&lt;think&gt;...&lt;/think&gt; A6<br/>User:</span>\n</div>\n\nThis approach breaks extension only every N timesteps instead of every timestep, keeps context size bounded (old thinking doesn't accumulate forever), and amortizes the recomputation cost over N turns.\n\nTo implement this, you would modify your environment or renderer to periodically transform the conversation history, stripping `<think>` blocks from messages older than N turns.\n\n## Summary\n\nFor `Qwen3Renderer`:\n- `strip_thinking_from_history=False` → Extension holds → Use for long trajectories where compute efficiency matters\n- `strip_thinking_from_history=True` (default) → Extension breaks → Use for short trajectories, or when you want minimal changes from base model behavior\n- Periodic compaction → Best of both worlds when you need efficiency with bounded context\n\nWhen designing your RL environment, consider how many turns you expect and whether the O(T) vs O(T²) difference will be significant for your use case.\n"
  },
  {
    "path": "docs/rl.mdx",
    "content": "import { CookbookLink } from '../components/CookbookLink'\n\n# Reinforcement learning\n\nReinforcement learning (RL) means learning from trial and error. Whereas in supervised learning, we're given input-output pairs, in RL, we're given inputs (prompts) and reward functions (i.e., a function for scoring candidate outputs). RL algorithms need to discover what good outputs look like.\n\nHere are a few different types of RL training that we support in the Tinker Cookbook:\n\n- *RL with Verifiable Rewards*: this is when we do RL on a reward function that checks model outputs using a program. Typically, the reward function checks the candidate answer against a reference answer, or, in coding cases, it may check if the candidate solution passes some unit tests. RLVR is especially suitable for teaching models to do reasoning (with chain-of-thought) and multi-step tool use (e.g., debugging and iterative modification of programs).\n- *RL on Human Feedback*: here, we assume we have an objective that can't be calculated by a simple program, and it requires some human judgement. For example, we typically want to optimize our models for helpfulness, which includes being clear, informative, and interesting. For RLHF, we train a *preference model* using supervised learning to match human judgement, scoring or ranking candidate outputs. Then we do RL on the preference model's scores. See the [Preferences](/preferences) section for more details.\n\nWe'll first show how to do small RL runs in the RLVR setting, then we'll show you how to define your own RL environments and train on them, then we'll provide examples for larger-scale or more complicated training setups.\n\n\nEvery training iteration writes human-readable HTML reports and machine-readable JSON files (metrics, rollout transcripts, per-trajectory summaries) to `log_path`. These are designed for both manual inspection and programmatic analysis — point an agent at your `log_path` to extract reward curves, model responses, and grading details. See [RL Training Outputs](/rl/rl-logging) for the full file reference and parsing examples.\n\nWe anticipate that people will want to use Tinker for RL in a few different ways:\n\n- Creating a specialist model that's SoTA at a specific skill, which existing models haven't been trained on. In this case, you'll want to start with a post-trained model that's already strong, and then do RL on an environment you've defined. See [RL Environments](/rl/rl-envs).\n- Doing research on post-training pipelines. In this case, you'll probably want to chain together SL and RL and runs with different data mixes, environments, and reward functions. See our [RLHF example](/preferences/rlhf-example).\n- Doing research on RL algorithms. Here, you'll probably want to find some existing environments to use as benchmarks, and either modify our provided training code (<CookbookLink path=\"tinker_cookbook/rl/train.py\">rl/train.py</CookbookLink>) or write your own minimal training loop. We've provided a [minimal training loop](/rl/rl-loops) that you can use as a starting point.\n"
  },
  {
    "path": "docs/save-load.mdx",
    "content": "# Saving and loading weights and optimizer state\n\nDuring training, you'll need to save checkpoints for two main purposes: *sampling* (to test your model) and *resuming training* (to continue from where you left off). The `TrainingClient` provides three methods to handle these cases:\n\n1. `save_weights_for_sampler()`: saves a copy of the model weights that can be used for sampling.\n2. `save_state()`: saves the weights and the optimizer state. You can fully resume training from this checkpoint.\n3. `load_state()`: load the weights and the optimizer state. You can fully resume training from this checkpoint.\n\nNote that (1) is faster and requires less storage space than (2).\n\nBoth `save_*` functions require a `name` parameter---a string that you can set to identify the checkpoint within the current training run. For example, you can name your checkpoints `\"0000\"`, `\"0001\"`, `\"step_1000\"`, etc.\n\nThe return value contains a `path` field, which is a fully-qualified path, which will look something like `tinker://<model_id>/<name>`. This path is persistent and can be loaded later by a new `ServiceClient` or `TrainingClient`.\n\n### Example: Saving for sampling\n\n```python\n# Setup\nimport tinker\nservice_client = tinker.ServiceClient()\ntraining_client = service_client.create_lora_training_client(\n    base_model=\"meta-llama/Llama-3.2-1B\", rank=32\n)\n\n# Save a checkpoint that you can use for sampling\nsampling_path = training_client.save_weights_for_sampler(name=\"0000\").result().path\n\n# Create a sampling client with that checkpoint\nsampling_client = service_client.create_sampling_client(model_path=sampling_path) #\n```\n\n**Shortcut:** Combine these steps with:\n\n```python\nsampling_client = training_client.save_weights_and_get_sampling_client(name=\"0000\")\n```\n\n### Example: Saving to resume training\n\nUse `save_state()` and `load_state()` when you need to pause and continue training with full optimizer state preserved:\n\n```python\n# Save a checkpoint that you can resume from\nresume_path = training_client.save_state(name=\"0010\").result().path\n\n# Load that checkpoint\ntraining_client.load_state(resume_path)\n```\n\n### When to use `save_state()` and `load_state()`:\n\n\n- Multi-step training pipelines (e.g. supervised learning followed by reinforcement learning)\n- Adjusting hyperparameters or data mid-run\n- Recovery from interruptions or failures\n- Any scenario where you need to preserve exact optimizer state (momentum, learning rate schedules, etc.)\n"
  },
  {
    "path": "docs/supervised-learning/prompt-distillation.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# Prompt Distillation\n\nPrompt distillation is a training technique in which a model is optimized to behave as though it had been provided with a long and complex prompt, without requiring access to that prompt during inference.\n\nAt a high level, this procedure involves two main steps:\n- **Creation of distillation data**: A teacher prompt, which is typically lengthy and highly detailed, provides explicit, step-by-step instructions. A teacher model uses this prompt to generate responses for a set of queries.\n- **Training the student model**: A student model is then trained (or fine-tuned) on the distilled dataset, thereby learning to reproduce the essential behaviors and reasoning encoded in the teacher’s instructions.\n\n---\n\n## Overview\n\nLet $f_T$ and $f_S$ denote the teacher and student models, respectively. Given an instruction prompt $P$ and a query $q_i$, the teacher model generates a response $r_i$:\n\n$$\nr_i = f_T([P, q_i])\n$$\n\nHere, the prompt $P$ and the query $q_i$ are concatenated to form the input to the teacher model $f_T$. For a dataset of queries $Q = \\{q_i \\mid 1 \\leq i \\leq D\\}$, we obtain a corresponding set of teacher responses $R = \\{r_i \\mid 1 \\leq i \\leq D\\}$.\n\nThe distillation training dataset is defined as the set of query–response pairs (excluding the original prompt):\n\n$$\nT = \\{(q_i, r_i) \\mid 1 \\leq i \\leq D\\}.\n$$\n\nThe student model $f_S$ is then trained to minimize the cross-entropy loss:\n\n$$\n\\ell(f_S(q_i), r_i) = \\ell(f_S(q_i), f_T([P, q_i])).\n$$\n\n---\n\n## Example\n\nThe Tinker Cookbook provides a prompt distillation recipe tailored for a language classification task. The objective is straightforward: given a text query, the model should predict a two-character code corresponding to the language of the input. 
The set of possible labels is:\n```\nar (Arabic), de (German), el (Greek), en (English), es (Spanish), fr (French), hi (Hindi), ru (Russian), tr (Turkish), ur (Urdu), vi (Vietnamese), zh (Chinese - Simplified), ot (Other/Unknown).\n```\n\nThe recipe in <CookbookLink path=\"tinker_cookbook/recipes/prompt_distillation/create_data.py\">recipes/prompt_distillation/create_data.py</CookbookLink> also includes handling strategies for inputs containing code, numerical content, or multiple languages.\n\nIn the example below, the same model (`Qwen/Qwen3-30B-A3B`) is used as both teacher and student, though in general they need not be identical.\n\n---\n\n### Step 1: Generate Training Data\n\nCreate prompt distillation data using the teacher model using <CookbookLink path=\"tinker_cookbook/recipes/prompt_distillation/create_data.py\">recipes/prompt_distillation/create_data.py</CookbookLink>:\n\n```bash\npython -m tinker_cookbook.recipes.prompt_distillation.create_data \\\n  output_file=/tmp/tinker-datasets/prompt_distillation_lang.jsonl\n```\n\nThis command will:\n- Use the configured teacher model to generate language classification examples\n- Save the distilled dataset to the specified output file\n- Create diverse training examples suitable for student model fine-tuning\n\n### Step 2: Train the Student Model\n\nFine-tune a student model on the distillation data using <CookbookLink path=\"tinker_cookbook/recipes/prompt_distillation/train.py\">recipes/prompt_distillation/train.py</CookbookLink>:\n\n```bash\npython -m tinker_cookbook.recipes.prompt_distillation.train\n```\n\nThe training script will:\n- Load the generated distillation dataset\n- Apply optimized training configurations\n- Fine-tune the student model for language classification\n\n### Step 3: Test Your Model\n\nOnce training is complete, you can test your distilled model by sampling from the trained model to verify its performance on language classification tasks.\n\n## Advanced Configuration\n\nThe prompt distillation recipe can be customized for different scenarios:\n\n- **Teacher model selection**: Choose different base models based on your requirements\n- **Sampling strategies**: Adjust temperature and other generation parameters\n- **Data volume**: Scale the number of generated examples based on your needs\n- **Training hyperparameters**: Fine-tune learning rates and other training settings\n"
  },
  {
    "path": "docs/supervised-learning/sl-basic.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# Basic Supervised Learning\n\nThis guide walks you through running your first supervised learning experiment using Tinker's built-in training loop.\n\n## Quick start\n\nWe've provided an implementation of supervised learning in <CookbookLink path=\"tinker_cookbook/supervised/train.py\">train_cli.py</CookbookLink>. To use this training loop, you'll need to create a `Config` object with the data and parameters.\n\nWe've provided a ready-to-run example that fine-tunes Llama-3.1-8B on a small instruction-following dataset in <CookbookLink path=\"tinker_cookbook/recipes/sl_basic.py\">sl_basic.py</CookbookLink>. You can run it from the command line as follows:\n\n```bash\npython -m tinker_cookbook.recipes.sl_basic\n```\n\nThis script fine-tunes the base (pretrained) model on a small dataset called [NoRobots](https://huggingface.co/datasets/HuggingFaceH4/no_robots), created by Hugging Face.\n\n### What you'll see during training\n\n- Each step you should see a printout of the train and test loss, along with other stats like timing.\n- The training script will also print out what the data looks like, with predicted tokens (weight=1) in green and context tokens (weight=0) in yellow.\n- The training script will write various logs and checkpoint info to the `log_path` directory, which is set to `/tmp/tinker-examples/sl_basic` in the example script.\n\n### Understanding the output files\nLooking at the `log_path` directory, you will find several files of interest:\n- `metrics.jsonl`: the training metrics that also were printed to the console. You can load and plot them like this:\n\n    ```python\n    import pandas\n    import matplotlib.pyplot as plt\n    df = pandas.read_json(\"/tmp/tinker-examples/sl_basic/metrics.jsonl\", lines=True)\n    plt.plot(df['train_mean_nll'], label='train_loss')\n    plt.plot(df['test/nll'].dropna(), label='test_loss')\n    plt.legend()\n    plt.show()\n    ```\nYou should see a plot like this:\n![Train and test loss as a function of steps](./images/train_test_loss.png)\n\n\n- `checkpoints.jsonl`: the checkpoints that were saved during training. Recall from [Saving and Loading](/save-load) that there are (currently) two kinds of checkpoints: one that has \"/sampler_weights/\" in the path (used for sampling), and the other that has \"/weights/\" in the path (includes full optimizer state, used for resuming training). If you interrupt the training script, then run it again, it will ask you if you want to resume training. If you choose to do so, it'll load the last (full state) checkpoint from this file.\n- `config.json`: the configuration that you used for training.\n\nIn the `sl_basic` script, you'll see that there's also some disabled code (under `if 0:`) that shows how to use your own dataset, specified as a JSONL file, provided in the format of <CookbookLink path=\"tinker_cookbook/example_data/conversations.jsonl\">conversations.jsonl</CookbookLink>.\n"
  },
  {
    "path": "docs/supervised-learning/sl-hyperparams.mdx",
    "content": "# Supervised Learning Hyperparameters\n\nSuccessful LLM fine-tuning requires careful hyperparameter tuning. While the most accurate approach is to sweep over ranges and selecting values that minimize loss or maximize eval performance for each hyperparameter, this is often time-consuming and expensive. This guide provides some starting recommendations for the most important hyperparameters.\n\n\n## Learning rate\n\nThe most important hyperparameter is generally the learning rate (LR). Our current best estimate of optimal LR for a model $m$ is the following:\n\n$$ LR(m) = lr_{base} · M_{LoRA} · \\Big(\\frac{2000}{H_m}\\Big)^{P_m} $$\n\nwhere $lr_{base}$ is a constant base LR, $M_{LoRA}$ is a multiplier applied when using LoRA (1 if using full-finetuning), $H_m$ is the hidden size of the model $m$, and $P_m$ is a model-specific exponent adjustment. Importantly, this function is independent of the LoRA rank.\n\nOur current best estimates are the following: $lr_{base} = 5e-5$,\n$M_{LoRA} = 10$, $P_m = 0.0775$ for Qwen models and $P_m = 0.781$ for Llama models.\n\n### Getting the recommended learning rate\nYou can use the following function to get the recommended LR for any Llama or Qwen model:\n```\nfrom tinker_cookbook.hyperparam_utils import get_lr\nmodel_name = \"meta-llama/Llama-3.2-1B\"\nrecommended_lr = get_lr(model_name)\nprint(f\"Recommended LR: {recommended_lr}\")\n```\n### Validation\nWe validated this formula across diverse supervised fine-tuning experiments, varying datasets, dataset sizes, batch_sizes and lora_ranks.\n\nUsing our LR estimates resulted in \\<0.5% regret compared to exhaustive hyperparameter sweeps, where regret is defined as:\n\nWe can define the regret of using any lr as the following:\n$$regret(lr') = \\frac{loss(lr') - min_{lr} loss(lr)}{min_{lr} loss(lr)}$$\n\n\n## Batch size\n\nBatch size is the second-most important hyperparameter; it significantly affects both training efficiency and final performance.\n\nFor small batch sizes, there's a phenomenon of *perfect scaling*, where the LR and batchsize should be varied together as $LR \\propto \\sqrt{B}$, and the learning curve only depends on $\\frac{LR}{\\sqrt{B}}$. See [Shallue et al. (2018)](https://arxiv.org/abs/1811.03600) for an example in the training-from-scratch setting.\n\nWhen fine-tuning LLMs, we're often in a regime where smaller batch sizes give better performance, at the cost of longer training time; moreover, the $LR \\propto \\sqrt{B}$ scaling doesn't always hold. When doing SL fine-tuning, we recommend using smaller batch sizes like 128, depending on your tolerance for longer training time.\n\nFor best results, you should aim for at least 100 steps of training (but usually get best results with 1000 or more).\n\n⚠️ Note: Our batch size recommendations are based on preliminary findings and ongoing research. We're not confident about them!\n"
  },
  {
    "path": "docs/supervised-learning/sl-loop.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# Supervised Learning Training Loop\n\nWe've provided a simple SL training loop in <CookbookLink path=\"tinker_cookbook/recipes/sl_loop.py\">sl_loop.py</CookbookLink>, which avoids using our dataset classes and instead defines the data loading in a more self-contained way. This is for people who like to write their own training loops or learn about how things work under the hood. Our more performant implementation in <CookbookLink path=\"tinker_cookbook/supervised/train.py\">supervised/train.py</CookbookLink> does basically the same thing, but with some performance optimizations, and with some additional features like periodic evals.\n"
  },
  {
    "path": "docs/supervised-learning/sweep-case-study.mdx",
    "content": "import { CookbookLink } from '../../components/CookbookLink'\n\n# Sweep case study\n\nIn [Supervised Learning Hyperparameters](./sl-hyperparams), we introduced default hyperparameters as a starting point. While defaults are useful, optimal values are often task-specific. A hyperparameter sweep---systematically testing values across a range---is a more reliable way to identify the best settings for your use case.\n\nThis guide demonstrates how to sweep over the **learning rate (LR)** to find an optimal value.\n\n## Why sweep the learning rate?\n\nThe learning rate is typically the most impactful hyperparameter. While our default recommendations perform well (usually \\<0.5% regret), you can often achieve even better results by sweeping to find the task-specific optimum.\n\n\n## Setup\n\nWe use the simple supervised learning training loop in\n<CookbookLink path=\"tinker_cookbook/recipes/sl_loop.py\">sl_loop.py</CookbookLink>, which trains a Llama-3.1-8B model.\n\nTo retrieve the model’s default learning rate recommendation:\n```\nfrom tinker_cookbook.hyperparam_utils import get_lr\nprint(get_lr(\"meta-llama/Llama-3.1-8B\"))\n```\nThis should output\n```\n0.0002856415043086949  # ≈ 2.8e-4\n```\nThis default value provides a baseline. A common best practice is to sweep one order of magnitude above and below the default. For this case, we sweep over: $LR \\in [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3]$\n\n\n\n## Running the sweep\nLaunch experiments in parallel, using separate terminal windows for each LR value. For example:\n```bash\npython -m tinker_cookbook.recipes.sl_loop learning_rate=0.003 log_path=/tmp/sft-lr-sweep/lr-0.003\npython -m tinker_cookbook.recipes.sl_loop learning_rate=0.001 log_path=/tmp/sft-lr-sweep/lr-0.001\npython -m tinker_cookbook.recipes.sl_loop learning_rate=0.0003 log_path=/tmp/sft-lr-sweep/lr-0.0003\npython -m tinker_cookbook.recipes.sl_loop learning_rate=0.0001 log_path=/tmp/sft-lr-sweep/lr-0.0001\npython -m tinker_cookbook.recipes.sl_loop learning_rate=0.00003 log_path=/tmp/sft-lr-sweep/lr-0.00003\npython -m tinker_cookbook.recipes.sl_loop learning_rate=0.00001 log_path=/tmp/sft-lr-sweep/lr-0.00001\n```\nYou can also automate this process by writing a script that spawns multiple tmux windows and launches experiments programmatically. 
This is especially useful for larger sweeps.\n\n\n## Collecting Results\nAfter the experiments are complete, you can read the `metrics.jsonl` files:\n```python\nfrom glob import glob\nimport pandas\nimport os\nimport json\n\ndata = []\nfor fname in sorted(glob(os.path.expanduser(\"/tmp/sft-lr-sweep/*/metrics.jsonl\"))):\n    df = pandas.read_json(fname, lines=True)\n    # make sure the experiment is completed\n    if len(df) == 0 or df[\"progress\"].iloc[-1] < 0.98:\n        continue\n    config_fname = fname.replace(\"metrics.jsonl\", \"config.json\")\n    with open(config_fname, \"rb\") as f:\n        metadata = json.load(f)\n    data.append({\n        \"fname\": fname,\n        \"learning_rate\": metadata[\"learning_rate\"],\n        \"final_loss\": df[\"train_mean_nll\"].iloc[-1].item()\n    })\n\nprint(f\"Read metrics for {len(data)} experiments\")\n```\nIf all the experiments are completed, the above code should print:\n```\nRead metrics for 6 experiments\n```\n\n## Visualizing the Sweep\nPlot the `final_loss` as a function of `learning_rate`:\n```python\nimport matplotlib.pyplot as plt\ndf = pandas.DataFrame(data)\nplt.plot(df[\"learning_rate\"], df[\"final_loss\"], marker='o')\nplt.axhline(y=df[\"final_loss\"].min(), color=\"green\", linestyle=\"--\")\nplt.ylim(1.65, 1.8)\nplt.xscale(\"log\")\nplt.xlabel(\"Learning Rate (log scale)\")\nplt.ylabel(\"Final Loss\")\nplt.title(\"Final Loss vs Learning Rate\")\nplt.show()\n```\nYou should see a U-shaped curve, similar to this:\n![final_loss_vs_lr](./images/lr_sweep.png)\n\nIf the full U-curve is not visible in your setting, expand the sweep range by adding more LR values.\n\n\n## Determining the Optimal LR\nThe optimal learning rate is the one that minimizes the loss. The plot above shows that the optimal LR is `3e-4` which you can also calculate by finding the minima:\n```\noptimal_lr = df[\"learning_rate\"][df[\"final_loss\"].idxmin()]\nprint(f\"The optimal LR is {optimal_lr:.2e}\")\n```\nExpected output:\n```\nThe optimal LR is 3.00e-04\n```\n\nNote that the optimal LR in our sweep (`3e-4`) is very close to the default LR (`2.8e-4`). However, task-specific sweeps can still provide marginal improvements and greater confidence in your hyperparameter choices.\n\n## Next steps\nNow that you've identified the optimal learning rate:\n1. Retrain with the optimal LR for your production run\n2. Consider sweeping other hyperparameters like batch size, warmup steps, or weight decay\n3. Use the optimal LR as a baseline for future experiments on similar tasks\n"
  },
  {
    "path": "docs/supervised-learning.mdx",
    "content": "import { CookbookLink } from '../components/CookbookLink'\n\n# Cookbook: Supervised learning\n\nThis section takes you through examples from the Tinker Cookbook that relate to supervised learning.\n\nIn general, supervised learning (SL) means learning an input-output mapping from labeled data. In the context of language model fine-tuning, this means **minimizing a weighted cross-entropy loss** on token sequences---equivalently, maximizing the log-probability of the specified target tokens.\n\nThere are a few ways that SL is commonly used in LLM fine-tuning pipelines:\n\n- *Instruction tuning*: This is the first step in post-training pipelines, applied to the base (raw, pretrained) model. Typically, we do SL on a high-quality dataset that demonstrates the correct format and style, while boosting the model's reasoning and instruction-following.\n- *Context distillation* / *prompt distillation*: let's say we have a generic model that can do chat / instruction following / reasoning, but we want to adjust how it behaves in a certain scenario. We can add some instructions to the system message of our model. However, the system message might grow impractically long and start ignoring some of its instructions. So it's often better to create a supervised dataset on a narrow prompt distribution, with a shorter set of instructions that that are targeted at these prompts.\n\nWe'll cover both of these use cases in this documentation and related Cookbook code.\n\nThe library code implementing supervised learning can be found in the <CookbookLink path=\"tinker_cookbook/supervised\">`supervised`</CookbookLink> directory.\n"
  },
  {
    "path": "docs/support.mdx",
    "content": "# Support\n\n## Get Tinker Support\n\nOur official support email is [tinker@thinkingmachines.ai](mailto:tinker@thinkingmachines.ai) for customer questions and issues. You can also find helpful community discussion on [the Tinker Discord](https://discord.gg/KqqEZNX88c).\n"
  },
  {
    "path": "docs/training-sampling.mdx",
    "content": "import { Callout } from 'nextra/components'\n\n# Getting started with training and sampling\n\nIn this guide, we'll step you through using the Tinker Python library to do the basic operations needed for training and sampling.\n[View the complete Python script →](/quickstart.py.txt)\n\n## Creating the training client\n\nThe main object we'll be using is the `TrainingClient`, which corresponds to a fine-tuned model that we can train and sample from.\n\nFirst, set your Tinker API key environment variable. In the terminal where you'll run Python, or in your `.bashrc`, put `export TINKER_API_KEY=<your key>`.\n\nThen, create a `ServiceInterface`. This lets you find out what base models are available to be fine-tuned.\n\n```python\nimport tinker\nservice_client = tinker.ServiceClient()\nprint(\"Available models:\")\nfor item in service_client.get_server_capabilities().supported_models:\n    print(\"- \" + item.model_name)\n```\nYou'll see a list of model names:\n```\n- meta-llama/Llama-3.1-70B\n- meta-llama/Llama-3.1-8B\n...\n- Qwen/Qwen3-VL-30B-A3B-Instruct\n- Qwen/Qwen3-VL-235B-A22B-Instruct\n```\nWe currently support models from the Qwen3, Qwen3-VL, and Llama3 series. We'll use Qwen3-VL-30B-A3B-Instruct for these examples, as it's a vision-language model that can also handle text-only tasks. See [Available Models in Tinker](/model-lineup) for the full list.\n\nNow we can create the `TrainingClient`:\n```python\nbase_model = \"Qwen/Qwen3-VL-30B-A3B-Instruct\"\ntraining_client = service_client.create_lora_training_client(\n    base_model=base_model\n)\n```\nAs the name suggests, this model was already finetuned for chat/instruction-following. You should check the details of the model you're using in their system cards.\n\n## Preparing the training data\n\nNow we can do training updates on the model. This quickstart example won't show best practices for LLM fine-tuning; it's just an API demo. Check out [Rendering](/rendering), [Supervised Fine-tuning](/supervised-learning) and the other Cookbook examples for guidance on how to use Tinker in real applications.\n\nFor this model, we'll train a model that can translate words into Pig Latin. 
The rules for Pig Latin are simple:\n- If a word begins with a consonant, move it to the end and add \"ay\"\n- If a word begins with a vowel, just add \"way\" to the end\n\nHere are some example completions we'd like the model to perform, where the prompt is in green and the model's completion is in red:\n\n<div className=\"example\">\n<span className=\"prompt\">English: hello world<br/>\nPig Latin: </span><span className=\"completion\">ello-hay orld-way</span>\n</div>\n\nLet's create some training examples and convert them to a format expected by Tinker.\n\n```python\n# Create some training examples\nexamples = [\n    {\"input\": \"banana split\", \"output\": \"anana-bay plit-say\"},\n    {\"input\": \"quantum physics\", \"output\": \"uantum-qay ysics-phay\"},\n    {\"input\": \"donut shop\", \"output\": \"onut-day op-shay\"},\n    {\"input\": \"pickle jar\", \"output\": \"ickle-pay ar-jay\"},\n    {\"input\": \"space exploration\", \"output\": \"ace-spay exploration-way\"},\n    {\"input\": \"rubber duck\", \"output\": \"ubber-ray uck-day\"},\n    {\"input\": \"coding wizard\", \"output\": \"oding-cay izard-way\"},\n]\n\n# Convert examples into the format expected by the training client\nfrom tinker import types\n\n# Get the tokenizer from the training client\ntokenizer = training_client.get_tokenizer()\n\ndef process_example(example: dict, tokenizer) -> types.Datum:\n    # Format the input with Input/Output template\n    # For most real use cases, you'll want to use a renderer / chat template,\n    # (see later docs) but here, we'll keep it simple.\n    prompt = f\"English: {example['input']}\\nPig Latin:\"\n\n    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)\n    prompt_weights = [0] * len(prompt_tokens)\n    # Add a space before the output string, and finish with double newline\n    completion_tokens = tokenizer.encode(f\" {example['output']}\\n\\n\", add_special_tokens=False)\n    completion_weights = [1] * len(completion_tokens)\n\n    tokens = prompt_tokens + completion_tokens\n    weights = prompt_weights + completion_weights\n\n    input_tokens = tokens[:-1]\n    target_tokens = tokens[1:] # We're predicting the next token, so targets need to be shifted.\n    weights = weights[1:]\n\n    # A datum is a single training example for the loss function.\n    # It has model_input, which is the input sequence that'll be passed into the LLM,\n    # loss_fn_inputs, which is a dictionary of extra inputs used by the loss function.\n    return types.Datum(\n        model_input=types.ModelInput.from_ints(tokens=input_tokens),\n        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)\n    )\n\nprocessed_examples = [process_example(ex, tokenizer) for ex in examples]\n\n# Visualize the first example for debugging purposes\ndatum0 = processed_examples[0]\nprint(f\"{'Input':<20} {'Target':<20} {'Weight':<10}\")\nprint(\"-\" * 50)\nfor i, (inp, tgt, wgt) in enumerate(zip(datum0.model_input.to_ints(), datum0.loss_fn_inputs['target_tokens'].tolist(), datum0.loss_fn_inputs['weights'].tolist())):\n    print(f\"{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}\")\n```\n\nThe visualization of the first example is:\n\n```\nInput                Target               Weight\n--------------------------------------------------\n'English'            ':'                  0\n':'                  ' banana'            0\n...\n' Latin'             ':'                  0\n':'                  ' an'                1\n' an'                'ana'        
        1\n...\n'ay'                 '\\n\\n'               1\n```\n\n## Vision inputs\n\nThe above example is text-only, but adding vision inputs is also straightforward. The `ModelInput` type takes a list of chunks, which can be either `EncodedTextChunk` or `ImageChunk`. For instance:\n\n```python\nimage_data = requests.get(\"https://thinkingmachines.ai/blog/on-policy-distillation/images/chess.png\").content\nmodel_input = tinker.ModelInput(chunks=[\n  types.EncodedTextChunk(tokens=tokenizer.encode(\"<|im_start|>user\\n<|vision_start|>\")),\n  types.ImageChunk(data=image_data, format=\"png\"),\n  types.EncodedTextChunk(tokens=tokenizer.encode(\"<|vision_end|>What is this?<|im_end|>\\n<|im_start|>assistant\\n\")),\n])\n```\n\nNote that Qwen3-VL was trained with special tokens like `<|vision_start|>` and `<|vision_end|>`. The cookbook's `Qwen3VLRenderer` handles these automatically—see [Rendering: Vision Inputs](/rendering#vision-inputs) for details and a complete example.\n\n## Performing a training update\n\nNow we can use this data to perform a training update. We'll do 6 updates on the same batch of data. (Note that this is not typically a good way to train!)\n\n```python\nimport numpy as np\nfor _ in range(6):\n    fwdbwd_future = training_client.forward_backward(processed_examples, \"cross_entropy\")\n    optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))\n\n    # Wait for the results\n    fwdbwd_result = fwdbwd_future.result()\n    optim_result = optim_future.result()\n\n    # fwdbwd_result contains the logprobs of all the tokens we put in. Now we can compute the weighted\n    # average log loss per token.\n    logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])\n    weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])\n    print(f\"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}\")\n```\n\nNote that the `forward_backward` and `optim_step` functions immediately return *futures*, which acknowledge that the task has been queued up by the server. For improved speed, we submitted both operations before waiting for the result by calling `result()` on the futures.\n\n\n## Sampling from the model\n\nNow we can test our model by sampling from it. In this case, we'll translate the phrase \"coffee break\" into Pig Latin.\n\n```python\n# First, create a sampling client. We need to transfer weights\nsampling_client = training_client.save_weights_and_get_sampling_client(name='pig-latin-model')\n\n# Now, we can sample from the model.\nprompt = types.ModelInput.from_ints(tokenizer.encode(\"English: coffee break\\nPig Latin:\"))\nparams = types.SamplingParams(max_tokens=20, temperature=0.0, stop=[\"\\n\"]) # Greedy sampling\nfuture = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)\nresult = future.result()\nprint(\"Responses:\")\nfor i, seq in enumerate(result.sequences):\n    print(f\"{i}: {repr(tokenizer.decode(seq.tokens))}\")\n```\n\nSince sampling is nondeterministic (sadly, even with temperature=0.0, [due to batching](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/)), the output will be different each time. You should see something like:\n\n```\nResponses:\n0: ' offe-bay eak-bay\\n\\n'\n1: ' offey-coy eak-bray\\n\\n'\n2: ' offecay eakbray\\n\\n'\n...\n```\n\n### Computing logprobs for a sequence\n\nWe can use the sampler to compute logprobs for a given sequence as well. 
This uses the prefill step and is returned as _prompt logprobs_.\n\n```python\nprompt = types.ModelInput.from_ints(tokenizer.encode(\"How many r's are in the word strawberry?\"))\nsample_response = sampling_client.sample(\n    prompt=prompt,\n    num_samples=1,\n    sampling_params=tinker.SamplingParams(max_tokens=1),  # Must be at least 1 token, represents prefill step\n    include_prompt_logprobs=True,\n).result()\n\n# example: [None, -9.5, -1.6, -8.8, -3.5, -8.3, ...]\nprint(sample_response.prompt_logprobs)\n```\n\nThe first logprob is `None` (corresponding to the first token), and subsequent entries are logprobs of each token in the prompt.\n\nThe sampling client also has a helper function, which is equivalent to the above:\n\n```python\nsampling_client.compute_logprobs(prompt).result()\n```\n\n### Top-k logprobs\n\nFor distillation, it may be especially useful to compute _top-k logprobs_ for each token as well, which gives you a sense of what the model \"would have said\" after each prefix, rather than only the logprob of the token that actually appears in the prompt.\n\n```python\nsample_response = sampling_client.sample(\n    prompt=prompt,\n    num_samples=1,\n    sampling_params=tinker.SamplingParams(max_tokens=1),\n    include_prompt_logprobs=True,\n    topk_prompt_logprobs=5,\n).result()\n\n# example: [None, [(14924, -1.2), (755, -2.2), ...], [(25, -1.6), (3137, -2.4), ...], ...]\nsample_response.topk_prompt_logprobs\n```\n\nFor each position in the prompt, this returns a list of `(token_id, logprob)` pairs for the top-k most likely tokens at that position.\n\n## Putting it together: Sampling from an image\n\nHere's a complete example that creates a training client, saves weights for sampling, and asks a question about an image:\n\n```python\nimport requests\nimport tinker\nfrom transformers import AutoTokenizer\n\nmodel_name = \"Qwen/Qwen3-VL-30B-A3B-Instruct\"\ntokenizer = AutoTokenizer.from_pretrained(model_name)\n\nservice_client = tinker.ServiceClient()\ntraining_client = await service_client.create_lora_training_client_async(base_model=model_name, rank=32)\nsampling_client = await training_client.save_weights_and_get_sampling_client_async(name=\"sampler\")\n\n# Grab an image and ask a question\nimage_data = requests.get(\"https://thinkingmachines.ai/blog/on-policy-distillation/images/chess.png\").content\nmodel_input = tinker.ModelInput(chunks=[\n    tinker.types.EncodedTextChunk(tokens=tokenizer.encode(\"<|im_start|>user\\n<|vision_start|>\")),\n    tinker.types.ImageChunk(data=image_data, format=\"png\"),\n    tinker.types.EncodedTextChunk(tokens=tokenizer.encode(\"<|vision_end|>What is this?<|im_end|>\\n<|im_start|>assistant\\n\")),\n])\n\nresult = await sampling_client.sample_async(prompt=model_input, num_samples=1, sampling_params=tinker.types.SamplingParams(max_tokens=100))\nprint(tokenizer.decode(result.sequences[0].tokens))\n```\n\nFor higher-level abstractions that handle special tokens automatically, see [Rendering: Vision Inputs](/rendering#vision-inputs).\n"
  },
  {
    "path": "docs/under-the-hood.mdx",
    "content": "# Under the Hood\n\nThis page explains some implementation details of Tinker, which are important for understanding how to speed up your code.\n\n## Clock Cycles\n\nIn Tinker, after you call `ServiceClient.create_lora_training_client`, your training job gets assigned to a pool of machines that work together -- a *worker pool* -- which are doing forward-backward operations repeatedly in lock-step.\nEach of these steps of the worker pool is called a *clock cycle*.\nIn each clock cycle, we do forward-backward and an optimizer step operation, each of which may involve multiple LoRA models that are being trained by this pool.\nYou can think of this pool as a single large training run that is time-shared between multiple different LoRA models, often from different users.\n\nWith multi-tenancy -- sharing the same worker pool between multiple models -- we can run the training system efficiently even if users are training with small batch sizes, or if they have other delays in their training loops that would otherwise leave the worker pool idle. Small batch sizes can often give better *sample efficiency*, so this setup lets us achieve both high compute efficiency and high sample efficiency.\n\nThe downside is that it can sometimes lead to worse *latency*: even if training with a small batch, you'll still see the same step time as a large batch. (Still, note that we'll only charge you for the compute you use.) Also, if your training loop is implemented naively, you might have to wait multiple clock cycles per batch, because you might miss a clock cycle between operations.\n\n### Overlapping `forward_backward` and `optim_step` Requests\n\nAs mentioned in the [Async and Futures](/async) section, you should submit your `forward_backward` and `optim_step` requests together before waiting for either of them. This way, they'll end up on the same clock cycle. If you write the code naively, you'll end up using *three* clock cycles per training step. Here's a recap of the example from the [Async and Futures](/async) section:\n\n**❌ Naive implementation (uses 3 clock cycles):**\n```python\n# Submit forward_backward, gets queued for clock cycle N\nfwd_bwd_future = await client.forward_backward_async(batch, loss_fn)\n\n# Wait for it to complete, and for client to receive the result\n# Due to communication latency, this happens a little after cycle N+1 started\nfwd_bwd_result = await fwd_bwd_future\n\n# Submit optim_step, gets queued for clock cycle N+2\noptim_future = await client.optim_step_async(adam_params)\n\n# Wait for it to complete, and for client to receive the result\n# This happens a little after cycle N+2 finishes\noptim_result = await optim_future\n\n# Total: forward_backward on cycle N, optim_step on cycle N+2\n# This takes 3 clock cycles (plus the time we waited before cycle N started)\n```\n\n**✓ Better implementation (uses 1 clock cycle):**\n```python\n# Submit both requests immediately. 
They'll both be slotted into the same clock cycle N\nfwd_bwd_future = await client.forward_backward_async(batch, loss_fn)\noptim_future = await client.optim_step_async(adam_params)\n\n# Now wait for results - both operations happen on cycle N\nfwd_bwd_result = await fwd_bwd_future\noptim_result = await optim_future\n\n# Total: both operations on cycle N\n# This takes 1 clock cycle\n```\n\n### Pipelining to Maximize Clock Cycle Efficiency\n\nTo maximize efficiency and avoid missing clock cycles, you should **pipeline your training loop**: submit the next batch before waiting for the current batch to complete. This ensures there's always a request queued when a new clock cycle starts.\n\nWe've created a demonstration script that shows the difference between pipelined and non-pipelined training:\n\n[View the clock cycles demonstration script →](/clock_cycles.py.txt)\n\nThe script includes two versions:\n\n- **Non-pipelined**: Submits a batch, waits for it to complete, then submits the next. This approach typically wastes clock cycles because there's a gap between when one batch finishes and the next is submitted, often using 2 clock cycles per training step.\n\n- **Pipelined**: Submits the next batch *before* waiting for the previous batch to complete. This approach often uses exactly 1 clock cycle per step, achieving maximum efficiency. Though it might sometimes take more than 1 clock cycle per step if the server is heavily loaded, or due to subtleties of our current implementation. (For example, if there are no other users, we might start the clock cycle after receiving the first `forward_backward` but before receiving the `optim_step`. Then we'll do `optim_step` on the next cycle. This causes an extra clock cycle but doesn't cause a slowdown.)\n\nRunning the script will show you the performance comparison, including total time and clock cycles used. The pipelined version typically saves both time and clock cycles.\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"tinker_cookbook\"\ndynamic = [\"version\"]\ndescription = \"Implementations of post-training algorithms using the Tinker API\"\nreadme = \"README.md\"\nauthors = [\n{ name = \"Tinker authors\", email = \"tinker@thinkingmachines.ai\" },\n]\nlicense = \"Apache-2.0\"\nrequires-python = \">=3.11\"\ndependencies = [\n    \"aiohttp\",\n    \"anyio\",\n    \"blobfile\",\n    \"chz\",\n    \"cloudpickle\",\n    \"datasets>=2.14.0\",\n    \"huggingface_hub\",\n    \"numpy>=1.24.0\",\n    \"pillow\",\n    \"pydantic\",\n    \"rich\",\n    \"safetensors\",\n    \"termcolor\",\n    \"tiktoken>=0.12.0\", # Required for Kimi tokenizer\n    \"tinker>=0.9.0\",\n    \"torch>=2.0\",\n    \"tqdm\",\n    \"transformers>=4.57.6,<=5.3.0\",\n]\n\n[project.urls]\nHomepage = \"https://thinkingmachines.ai/tinker\"\nRepository = \"https://github.com/thinking-machines-lab/tinker-cookbook\"\nDocumentation = \"https://tinker-docs.thinkingmachines.ai/\"\n\n[project.optional-dependencies]\ndev = [\n    \"pytest\",\n    \"pytest-asyncio\",\n    \"pytest-timeout\",\n    \"ruff\",\n    \"pyright\",\n]\nmath-rl = [\n    \"math-verify\",\n    \"pylatexenc\",\n    \"sympy\",\n]\nmodal = [\n    \"modal\",\n]\nmultiplayer-rl = [\n    \"textarena>=0.7.4\",\n]\nvector-search = [\n    \"chromadb\",\n    \"google-genai\",\n    \"huggingface_hub\",\n]\nwandb = [\n    \"wandb\",\n    \"plotly\",\n]\nneptune-scale = [\n    \"neptune-scale>=0.27.0\",\n]\ntrackio = [\n    \"trackio<1.0.0\",\n]\nverifiers = [\n    \"verifiers>=0.1.9,<0.1.10\",\n    \"openai\",\n]\ninspect = [\n    \"inspect-ai\",\n    \"inspect-evals>=0.3.106\",\n]\nlitellm = [\n    \"litellm\",\n]\nall = [\n    \"tinker_cookbook[math-rl]\",\n    \"tinker_cookbook[modal]\",\n    \"tinker_cookbook[multiplayer-rl]\",\n    \"tinker_cookbook[vector-search]\",\n    \"tinker_cookbook[wandb]\",\n    \"tinker_cookbook[neptune-scale]\",\n    \"tinker_cookbook[trackio]\",\n    \"tinker_cookbook[verifiers]\",\n    \"tinker_cookbook[inspect]\",\n    \"tinker_cookbook[litellm]\",\n]\n\n[build-system]\nrequires = [\"hatchling\", \"hatch-vcs\"]\nbuild-backend = \"hatchling.build\"\n\n[tool.hatch.version]\nsource = \"vcs\"\ntag-pattern = \"v(?P<version>.*)\"\n\n[tool.hatch.build.hooks.vcs]\nversion-file = \"tinker_cookbook/_version.py\"\n\n[tool.hatch.build.targets.wheel]\npackages = [\"tinker_cookbook\"]\nexclude = [\"*_test.py\"]\n\n[tool.pytest.ini_options]\ntestpaths = [\"tinker_cookbook\", \"tests\"]\npython_files = [\"*_test.py\", \"test_*.py\"]\nnorecursedirs = [\"tinker_cookbook/scripts\"]\nmarkers = [\n    \"slow: marks tests as slow (deselect with '-m \\\"not slow\\\"')\",\n    \"integration: marks tests requiring TINKER_API_KEY and network access\",\n    \"downstream_compat: marks tests that verify public API contracts for downstream consumers\",\n    \"timeout: per-test timeout in seconds (requires pytest-timeout)\",\n]\n\n[tool.ruff]\nline-length = 100\nexclude = [\n    # Vendored from HuggingFace, kept identical to upstream\n    \"kimi-k2.5-hf-tokenizer/tool_declaration_ts.py\",\n]\n\n[tool.ruff.lint]\nselect = [\n    \"E\",    # pycodestyle errors\n    \"F\",    # pyflakes (unused imports, undefined names)\n    \"I\",    # isort (import sorting)\n    \"UP\",   # pyupgrade (modernize Python syntax)\n    \"B\",    # flake8-bugbear (common bugs and design problems)\n    \"SIM\",  # flake8-simplify (simplifiable code)\n    \"RUF\",  # Ruff-specific rules\n    \"C4\",   # flake8-comprehensions (unnecessary comprehensions)\n]\nignore = [\n    
\"E501\",   # line too long (handled by formatter)\n    \"B028\",   # no explicit stacklevel in warnings\n    \"SIM108\", # ternary operator (can reduce readability)\n    \"UP007\",  # use X | Y for union types (not always clearer)\n    \"RUF001\", # ambiguous unicode character in string (intentional in tokenizer code)\n    \"RUF002\", # ambiguous unicode character in docstring\n    \"RUF003\", # ambiguous unicode character in comment\n    \"RUF005\", # collection-literal-concatenation (stylistic, not a bug)\n    \"RUF006\", # asyncio-dangling-task (false positives with our patterns)\n    \"RUF012\", # mutable-class-default (conflicts with chz dataclass patterns)\n    \"RUF046\", # unnecessary-cast-to-int (explicit casts aid readability)\n    \"B008\",   # function-call-in-default-argument (to be fixed in a follow-up)\n    \"B905\",   # zip-without-explicit-strict (to be tightened in a follow-up)\n    \"B023\",   # function-uses-loop-variable (to be fixed in a follow-up)\n    \"B027\",   # empty-method-without-abstract-decorator (intentional in base classes)\n    \"RUF059\", # unused-unpacked-variable (to be fixed in a follow-up)\n    \"SIM117\", # multiple-with-statements (can reduce readability with context managers)\n    \"RUF022\", # unsorted-dunder-all (we group __all__ by category, not alphabetically)\n]\n\n[tool.pyright]\ninclude = [\"tinker_cookbook\"]\nexclude = [\n    \".venv\",\n    # Vendored from HuggingFace, kept identical to upstream\n    \"kimi-k2.5-hf-tokenizer/tool_declaration_ts.py\",\n]\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/compare_sampling_training_logprobs.py",
    "content": "import asyncio\nimport logging\nimport time\nfrom functools import cache\n\nimport chz\nimport httpx\nimport pandas as pd\nimport tinker\nimport torch\nfrom tinker import AdamParams, ModelInput\n\nfrom tinker_cookbook.supervised.common import datum_from_model_input_weights\n\n\n@cache\ndef get_reference_document():\n    \"\"\"Download PyTorch's forward_ad.py file from a specific commit.\"\"\"\n    url = \"https://raw.githubusercontent.com/pytorch/pytorch/a10b765bf159a86fb2a0ad693c6b72e0c691e60b/torch/autograd/forward_ad.py\"\n    response = httpx.get(url)\n    response.raise_for_status()\n    return response.text\n\n\nasync def get_row(\n    model_name: str,\n    service_client: tinker.ServiceClient,\n    timeout_sec: float,\n    saved_path_for_trainer: str | None = None,\n    saved_path_for_sampler: str | None = None,\n    ttl_seconds: int | None = 604800,\n) -> dict:\n    async def _inner():\n        tstart = time.time()\n        print(f\"========== Testing {model_name} ==========\")\n        training_client = await service_client.create_lora_training_client_async(\n            base_model=model_name\n        )\n        if saved_path_for_trainer is not None:\n            await training_client.load_state_async(saved_path_for_trainer)\n        # First sample something\n        tokenizer = training_client.get_tokenizer()\n        tokens = tokenizer.encode(get_reference_document())\n        model_input = ModelInput.from_ints(tokens)\n        weights = torch.ones(len(tokens), dtype=torch.float32)\n        weights[0] = 0.0\n        datum = datum_from_model_input_weights(model_input, weights)\n        for _ in range(3 if saved_path_for_trainer is None else 0):\n            fwd_bwd_future = await training_client.forward_backward_async(\n                [datum], loss_fn=\"cross_entropy\"\n            )\n            optim_step_future = await training_client.optim_step_async(\n                adam_params=AdamParams(learning_rate=1e-3)\n            )\n            _fwd_bwd_result = await fwd_bwd_future.result_async()\n            _optim_step_result = await optim_step_future.result_async()\n        fwd_future = await training_client.forward_async([datum], loss_fn=\"cross_entropy\")\n        fwd_result = await fwd_future.result_async()\n        training_logprobs = fwd_result.loss_fn_outputs[0][\"logprobs\"].to_torch()\n        if saved_path_for_sampler is None:\n            state_for_trainer_future = await training_client.save_state_async(\n                name=\"tmp-checkpoint\", ttl_seconds=ttl_seconds\n            )\n            state_for_trainer = await state_for_trainer_future.result_async()\n            print(f\"Saved state for trainer: {state_for_trainer.path}\")\n            sampling_client = await training_client.save_weights_and_get_sampling_client_async(\n                name=\"tmp-checkpoint\"\n            )\n        else:\n            sampling_client = training_client.create_sampling_client(\n                model_path=saved_path_for_sampler\n            )\n        logprobs_response = await sampling_client.compute_logprobs_async(model_input)\n        sampling_logprobs = torch.tensor(logprobs_response[1:])\n        mse = ((sampling_logprobs - training_logprobs) ** 2).mean()\n\n        dur = time.time() - tstart\n        print(f\"Time taken: {dur:.1f} seconds\")\n        result = {\n            \"model_name\": model_name,\n            \"mse[sample, train]\": mse.item(),\n            \"time\": dur,\n        }\n        print(result)\n        return result\n\n    try:\n        
return await asyncio.wait_for(_inner(), timeout=timeout_sec)\n    except TimeoutError:\n        print(f\"ERROR: Timeout after {timeout_sec} seconds for model {model_name}\")\n        return {\"model_name\": model_name, \"error\": \"TimeoutError\"}\n\n\n@chz.chz\nclass Config:\n    base_url: str | None = None\n    print_models: bool = False\n    model_names: list[str] | None = None\n    model_name_filter: list[str] | None = chz.field(default_factory=lambda: [\"loadtest\"])\n    state_for_trainer: str | None = None\n    state_for_sampler: str | None = None\n    ttl_seconds: int | None = 604800  # 7 days\n\n\nasync def main(config: Config):\n    logging.basicConfig(level=logging.INFO)\n    service_client = tinker.ServiceClient(base_url=config.base_url)\n\n    if config.model_names is None:\n        server_capabilities = await service_client.get_server_capabilities_async()\n        model_names = [\n            model_info.model_name\n            for model_info in server_capabilities.supported_models\n            if model_info.model_name is not None\n        ]\n        if config.print_models:\n            print(\"Available models:\")\n            for model_name in model_names:\n                print(f\"- {model_name}\")\n            return\n    else:\n        model_names = list(config.model_names)\n\n    def should_do_model(model_name: str) -> bool:\n        if not config.model_name_filter:\n            return True\n        return not any(x in model_name for x in config.model_name_filter)\n\n    model_names = [x for x in sorted(model_names) if should_do_model(x)]\n    print(f\"Model names: {model_names}\")\n    timeout_sec = 300.0\n    rows = await asyncio.gather(\n        *[\n            get_row(\n                model_name,\n                service_client,\n                timeout_sec,\n                config.state_for_trainer,\n                config.state_for_sampler,\n                config.ttl_seconds,\n            )\n            for model_name in model_names\n        ]\n    )\n\n    df = pd.DataFrame(rows)\n    # Ensure df has all required columns with NaN for missing values\n    required_columns = [\"model_name\", \"mse[sample, train]\", \"time\", \"error\"]\n    for col in required_columns:\n        if col not in df.columns:\n            df[col] = pd.NA\n    df = df[required_columns]\n\n    df.to_csv(\"/tmp/sampling_training_logprobs.csv\", index=False)\n    print(df.to_markdown())\n\n\nif __name__ == \"__main__\":\n    asyncio.run(chz.nested_entrypoint(main))\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "\"\"\"Pytest configuration for integration tests.\n\nRecipes NOT yet covered by integration tests:\n  - code_rl: requires external sandbox service (SandboxFusion)\n  - search_tool: requires running Chroma vector DB + embedding API\n  - verifiers_rl: requires verifiers framework environment\n  - if_rl: requires if_verifiable library + IFBench data\n  - rubric: needs generated JSONL data (has generate_data.py script)\n  - rl_basic, sl_basic, rl_loop, sl_loop: standalone tutorial scripts (not full recipes)\n  - prompt_distillation: needs a local JSONL data file\n  - harbor_rl: needs Modal + downloaded Harbor tasks\n\"\"\"\n\nimport os\n\nimport pytest\n\n\ndef pytest_collection_modifyitems(config, items):\n    \"\"\"Skip smoke tests locally when TINKER_API_KEY is not set. Fail on CI.\"\"\"\n    if os.environ.get(\"TINKER_API_KEY\"):\n        return\n\n    # Separate smoke tests from downstream_compat tests (which don't need API keys)\n    smoke_items = [item for item in items if \"downstream_compat\" not in str(item.fspath)]\n    if not smoke_items:\n        return\n\n    if os.environ.get(\"CI\"):\n        pytest.fail(\"TINKER_API_KEY is not set but CI=true — smoke tests require an API key\")\n    skip = pytest.mark.skip(\n        reason=\"TINKER_API_KEY not set (set it or run pytest tinker_cookbook/ for unit tests)\"\n    )\n    for item in smoke_items:\n        item.add_marker(skip)\n"
  },
  {
    "path": "tests/downstream_compat/__init__.py",
    "content": ""
  },
  {
    "path": "tests/downstream_compat/conftest.py",
    "content": "\"\"\"Auto-apply the downstream_compat marker to every test in this directory.\n\nThese tests verify that tinker-cookbook's public API surface remains compatible\nwith downstream consumers (e.g., the internal downstream projects). They are fast, require\nno API keys or GPU, and run on every PR.\n\nRun just these tests:\n    uv run pytest tests/downstream_compat/\n    uv run pytest -m downstream_compat\n\"\"\"\n\nimport pytest\n\n\ndef pytest_collection_modifyitems(config, items):\n    marker = pytest.mark.downstream_compat\n    for item in items:\n        if \"downstream_compat\" in str(item.fspath):\n            item.add_marker(marker)\n"
  },
  {
    "path": "tests/downstream_compat/sig_helpers.py",
    "content": "\"\"\"Helpers for checking function/method signatures in downstream compat tests.\"\"\"\n\nimport inspect\n\n\ndef get_param_names(func) -> list[str]:\n    \"\"\"Return parameter names (excluding 'self') for a function or method.\"\"\"\n    sig = inspect.signature(func)\n    return [name for name in sig.parameters if name != \"self\"]\n\n\ndef assert_params(func, expected_params: list[str]) -> None:\n    \"\"\"Assert that a function has exactly the expected parameter names (excluding 'self').\"\"\"\n    actual = get_param_names(func)\n    assert actual == expected_params, (\n        f\"{func.__qualname__}: expected params {expected_params}, got {actual}\"\n    )\n\n\ndef assert_params_subset(func, required_params: list[str]) -> None:\n    \"\"\"Assert that a function has at least the required parameter names (in order).\"\"\"\n    actual = get_param_names(func)\n    for param in required_params:\n        assert param in actual, (\n            f\"{func.__qualname__}: missing required param '{param}', has {actual}\"\n        )\n"
  },
  {
    "path": "tests/downstream_compat/test_checkpoint_utils.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.checkpoint_utils.\n\nValidates that checkpoint management types and functions remain stable.\n\"\"\"\n\nfrom dataclasses import fields\n\nfrom tinker_cookbook.checkpoint_utils import (\n    CheckpointRecord,\n    get_last_checkpoint,\n    load_checkpoints_file,\n    save_checkpoint,\n)\n\n\nclass TestCheckpointRecord:\n    def test_fields(self):\n        names = {f.name for f in fields(CheckpointRecord)}\n        expected = {\"name\", \"batch\", \"epoch\", \"final\", \"state_path\", \"sampler_path\", \"extra\"}\n        assert expected.issubset(names)\n\n    def test_constructable_minimal(self):\n        record = CheckpointRecord(name=\"step_100\")\n        assert record.name == \"step_100\"\n        assert record.batch is None\n        assert record.extra == {}\n\n    def test_to_dict(self):\n        record = CheckpointRecord(name=\"step_100\", batch=100)\n        d = record.to_dict()\n        assert isinstance(d, dict)\n        assert d[\"name\"] == \"step_100\"\n\n    def test_from_dict(self):\n        d = {\"name\": \"step_100\", \"batch\": 100}\n        record = CheckpointRecord.from_dict(d)\n        assert record.name == \"step_100\"\n        assert record.batch == 100\n\n    def test_roundtrip(self):\n        original = CheckpointRecord(name=\"step_50\", batch=50, epoch=1, final=False)\n        restored = CheckpointRecord.from_dict(original.to_dict())\n        assert restored.name == original.name\n        assert restored.batch == original.batch\n\n    def test_has_method(self):\n        record = CheckpointRecord(name=\"test\", extra={\"key\": \"value\"})\n        assert record.has(\"key\") is True\n        assert record.has(\"missing\") is False\n\n    def test_get_method(self):\n        record = CheckpointRecord(name=\"test\", extra={\"key\": \"value\"})\n        assert record.get(\"key\") == \"value\"\n\n\nclass TestCheckpointFunctions:\n    def test_load_checkpoints_file_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(load_checkpoints_file, [\"log_dir\"])\n\n    def test_get_last_checkpoint_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(get_last_checkpoint, [\"log_dir\", \"required_key\"])\n\n    def test_save_checkpoint_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params_subset\n\n        assert_params_subset(save_checkpoint, [\"training_client\", \"name\", \"log_path\", \"loop_state\"])\n\n    def test_save_checkpoint_async_exists(self):\n        from tinker_cookbook import checkpoint_utils\n\n        assert hasattr(checkpoint_utils, \"save_checkpoint_async\")\n\n    def test_checkpoint_record_has_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(CheckpointRecord.has, [\"key\"])\n\n    def test_checkpoint_record_get_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params_subset\n\n        assert_params_subset(CheckpointRecord.get, [\"key\"])\n"
  },
  {
    "path": "tests/downstream_compat/test_cli_and_hyperparam.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.cli_utils and hyperparam_utils.\n\nValidates that CLI utilities and hyperparameter functions remain stable.\n\"\"\"\n\nfrom tinker_cookbook.cli_utils import check_log_dir\nfrom tinker_cookbook.hyperparam_utils import (\n    get_lora_lr_over_full_finetune_lr,\n    get_lora_param_count,\n    get_lr,\n)\n\n\nclass TestCliUtils:\n    def test_check_log_dir_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(check_log_dir, [\"log_dir\", \"behavior_if_exists\"])\n\n\nclass TestHyperparamUtils:\n    def test_get_lr_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(get_lr, [\"model_name\", \"is_lora\"])\n\n    def test_get_lora_lr_over_full_finetune_lr_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(get_lora_lr_over_full_finetune_lr, [\"model_name\", \"lora_alpha\"])\n\n    def test_get_lora_param_count_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params_subset\n\n        assert_params_subset(get_lora_param_count, [\"model_name\", \"lora_rank\"])\n\n    def test_get_lr_returns_float(self):\n        lr = get_lr(\"Qwen/Qwen3-8B\", is_lora=True)\n        assert isinstance(lr, float)\n        assert lr > 0\n"
  },
  {
    "path": "tests/downstream_compat/test_completers.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.completers.\n\nValidates that completer interfaces and types remain stable.\n\"\"\"\n\nimport inspect\n\nfrom tinker_cookbook.completers import (\n    MessageCompleter,\n    StopCondition,\n    TinkerTokenCompleter,\n    TokenCompleter,\n    TokensWithLogprobs,\n)\n\n\nclass TestTokensWithLogprobs:\n    def test_fields(self):\n        t = TokensWithLogprobs(tokens=[1, 2, 3], maybe_logprobs=[0.1, 0.2, 0.3])\n        assert t.tokens == [1, 2, 3]\n        assert t.maybe_logprobs == [0.1, 0.2, 0.3]\n\n    def test_logprobs_property(self):\n        t = TokensWithLogprobs(tokens=[1], maybe_logprobs=[0.5])\n        assert t.logprobs == [0.5]\n\n    def test_logprobs_raises_when_none(self):\n        import pytest\n\n        t = TokensWithLogprobs(tokens=[1], maybe_logprobs=None)\n        with pytest.raises(ValueError):\n            _ = t.logprobs\n\n    def test_none_logprobs(self):\n        t = TokensWithLogprobs(tokens=[1, 2], maybe_logprobs=None)\n        assert t.maybe_logprobs is None\n\n\nclass TestTokenCompleter:\n    def test_is_callable(self):\n        assert callable(TokenCompleter)\n\n    def test_call_is_async(self):\n        assert inspect.iscoroutinefunction(TokenCompleter.__call__)\n\n    def test_call_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(TokenCompleter.__call__, [\"model_input\", \"stop\"])\n\n\nclass TestMessageCompleter:\n    def test_is_callable(self):\n        assert callable(MessageCompleter)\n\n    def test_call_is_async(self):\n        assert inspect.iscoroutinefunction(MessageCompleter.__call__)\n\n    def test_call_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(MessageCompleter.__call__, [\"messages\"])\n\n\nclass TestTinkerTokenCompleter:\n    def test_is_subclass_of_token_completer(self):\n        assert issubclass(TinkerTokenCompleter, TokenCompleter)\n\n    def test_has_expected_fields(self):\n        # TinkerTokenCompleter is a dataclass with these fields\n        annotations = TinkerTokenCompleter.__dataclass_fields__\n        assert \"sampling_client\" in annotations\n        assert \"max_tokens\" in annotations\n        assert \"temperature\" in annotations\n\n\nclass TestStopCondition:\n    def test_is_type_alias(self):\n        # StopCondition should accept list[str] or list[int]\n        val_str: StopCondition = [\"<stop>\"]\n        val_int: StopCondition = [0]\n        assert isinstance(val_str, list)\n        assert isinstance(val_int, list)\n"
  },
  {
    "path": "tests/downstream_compat/test_model_info.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.model_info.\n\nValidates that model metadata functions and ModelAttributes remain stable.\n\"\"\"\n\nfrom dataclasses import fields\n\nfrom tinker_cookbook.model_info import (\n    ModelAttributes,\n    get_model_attributes,\n    get_recommended_renderer_name,\n    get_recommended_renderer_names,\n)\n\n\nclass TestModelAttributes:\n    def test_fields(self):\n        names = {f.name for f in fields(ModelAttributes)}\n        expected = {\n            \"organization\",\n            \"version_str\",\n            \"size_str\",\n            \"is_chat\",\n            \"recommended_renderers\",\n            \"is_vl\",\n        }\n        assert expected.issubset(names)\n\n    def test_constructable(self):\n        attrs = ModelAttributes(\n            organization=\"test-org\",\n            version_str=\"1.0\",\n            size_str=\"8B\",\n            is_chat=True,\n            recommended_renderers=(\"qwen3\",),\n        )\n        assert attrs.organization == \"test-org\"\n        assert attrs.is_vl is False  # default\n\n\nclass TestModelInfoFunctions:\n    def test_get_model_attributes_returns_model_attributes(self):\n        attrs = get_model_attributes(\"Qwen/Qwen3-8B\")\n        assert isinstance(attrs, ModelAttributes)\n        assert attrs.organization == \"Qwen\"\n\n    def test_get_recommended_renderer_name_returns_string(self):\n        name = get_recommended_renderer_name(\"Qwen/Qwen3-8B\")\n        assert isinstance(name, str)\n        assert len(name) > 0\n\n    def test_get_recommended_renderer_names_returns_list(self):\n        names = get_recommended_renderer_names(\"Qwen/Qwen3-8B\")\n        assert isinstance(names, list)\n        assert all(isinstance(n, str) for n in names)\n        assert len(names) > 0\n\n    def test_recommended_renderer_name_is_first_of_names(self):\n        name = get_recommended_renderer_name(\"Qwen/Qwen3-8B\")\n        names = get_recommended_renderer_names(\"Qwen/Qwen3-8B\")\n        assert name == names[0]\n\n\nclass TestModelInfoSignatures:\n    def test_get_model_attributes_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(get_model_attributes, [\"model_name\"])\n\n    def test_get_recommended_renderer_name_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(get_recommended_renderer_name, [\"model_name\"])\n\n    def test_get_recommended_renderer_names_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(get_recommended_renderer_names, [\"model_name\"])\n"
  },
  {
    "path": "tests/downstream_compat/test_recipes.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.recipes.\n\nValidates that recipe modules used by downstream remain importable and have\nthe expected API surface.\n\"\"\"\n\nimport inspect\n\n# ---------------------------------------------------------------------------\n# recipes.math_rl\n# ---------------------------------------------------------------------------\n\n\nclass TestMathRL:\n    def test_math_env_importable(self):\n        from tinker_cookbook.recipes.math_rl import math_env\n\n        assert math_env is not None\n\n    def test_arithmetic_env_importable(self):\n        from tinker_cookbook.recipes.math_rl import arithmetic_env\n\n        assert arithmetic_env is not None\n\n    def test_get_math_dataset_builder(self):\n        from tinker_cookbook.recipes.math_rl.math_env import get_math_dataset_builder\n\n        assert callable(get_math_dataset_builder)\n\n    def test_math_env_classes(self):\n        from tinker_cookbook.recipes.math_rl.math_env import (\n            Gsm8kDatasetBuilder,\n            MathDatasetBuilder,\n        )\n\n        assert Gsm8kDatasetBuilder is not None\n        assert MathDatasetBuilder is not None\n\n    def test_math_grading_functions(self):\n        from tinker_cookbook.recipes.math_rl.math_grading import (\n            extract_boxed,\n            grade_answer,\n            normalize_answer,\n        )\n\n        assert callable(grade_answer)\n        assert callable(normalize_answer)\n        assert callable(extract_boxed)\n\n    def test_safe_grade(self):\n        from tinker_cookbook.recipes.math_rl.math_env import safe_grade\n\n        assert callable(safe_grade)\n\n\n# ---------------------------------------------------------------------------\n# recipes.code_rl\n# ---------------------------------------------------------------------------\n\n\nclass TestCodeRL:\n    def test_code_env_importable(self):\n        from tinker_cookbook.recipes.code_rl.code_env import DeepcoderDatasetBuilder\n\n        assert DeepcoderDatasetBuilder is not None\n\n\n# ---------------------------------------------------------------------------\n# recipes.chat_sl\n# ---------------------------------------------------------------------------\n\n\nclass TestChatSL:\n    def test_chat_datasets_importable(self):\n        from tinker_cookbook.recipes.chat_sl import chat_datasets\n\n        assert chat_datasets is not None\n\n    def test_tulu3_builder_exists(self):\n        from tinker_cookbook.recipes.chat_sl.chat_datasets import Tulu3Builder\n\n        assert Tulu3Builder is not None\n\n\n# ---------------------------------------------------------------------------\n# recipes.preference\n# ---------------------------------------------------------------------------\n\n\nclass TestPreference:\n    def test_dpo_train_importable(self):\n        from tinker_cookbook.recipes.preference.dpo.train import CLIConfig, cli_main\n\n        assert CLIConfig is not None\n        assert callable(cli_main)\n\n    def test_preference_datasets_importable(self):\n        from tinker_cookbook.recipes.preference.datasets import HHHComparisonBuilder\n\n        assert HHHComparisonBuilder is not None\n\n\n# ---------------------------------------------------------------------------\n# recipes.rl_basic and sl_basic (used by config_utils)\n# ---------------------------------------------------------------------------\n\n\nclass TestBasicRecipes:\n    def test_rl_basic_build_config(self):\n        from tinker_cookbook.recipes.rl_basic import 
build_config_blueprint\n\n        assert callable(build_config_blueprint)\n\n    def test_sl_basic_build_config(self):\n        from tinker_cookbook.recipes.sl_basic import build_config_blueprint\n\n        assert callable(build_config_blueprint)\n\n\n# ---------------------------------------------------------------------------\n# eval.evaluators\n# ---------------------------------------------------------------------------\n\n\nclass TestEvaluators:\n    def test_sampling_client_evaluator_importable(self):\n        from tinker_cookbook.eval.evaluators import SamplingClientEvaluator\n\n        assert SamplingClientEvaluator is not None\n\n    def test_training_client_evaluator_importable(self):\n        from tinker_cookbook.eval.evaluators import TrainingClientEvaluator\n\n        assert TrainingClientEvaluator is not None\n\n    def test_evaluator_builder_importable(self):\n        from tinker_cookbook.eval.evaluators import EvaluatorBuilder\n\n        assert EvaluatorBuilder is not None\n\n\n# ---------------------------------------------------------------------------\n# distillation.datasets (used by tibo)\n# ---------------------------------------------------------------------------\n\n\nclass TestDistillation:\n    def test_prompt_only_env_importable(self):\n        from tinker_cookbook.distillation.datasets import PromptOnlyEnv\n\n        assert PromptOnlyEnv is not None\n\n    def test_load_tulu3_prompts_importable(self):\n        from tinker_cookbook.distillation.datasets import load_tulu3_prompts\n\n        assert callable(load_tulu3_prompts)\n\n\n# ---------------------------------------------------------------------------\n# preference.types (used by rl_cli)\n# ---------------------------------------------------------------------------\n\n\nclass TestPreferenceTypes:\n    def test_preference_model_builder_importable(self):\n        from tinker_cookbook.preference.types import PreferenceModelBuilderFromChatRenderer\n\n        assert PreferenceModelBuilderFromChatRenderer is not None\n\n\n# ---------------------------------------------------------------------------\n# supervised.train (entry point)\n# ---------------------------------------------------------------------------\n\n\nclass TestSupervisedTrain:\n    def test_config_exists(self):\n        from tinker_cookbook.supervised.train import Config\n\n        assert Config is not None\n\n    def test_main_exists(self):\n        from tinker_cookbook.supervised.train import main\n\n        assert callable(main)\n\n    def test_main_is_async(self):\n        from tinker_cookbook.supervised.train import main\n\n        assert inspect.iscoroutinefunction(main)\n\n\n# ---------------------------------------------------------------------------\n# utils.lr_scheduling\n# ---------------------------------------------------------------------------\n\n\nclass TestLRScheduling:\n    def test_lr_schedule_importable(self):\n        from tinker_cookbook.utils.lr_scheduling import LRSchedule\n\n        assert LRSchedule is not None\n"
  },
  {
    "path": "tests/downstream_compat/test_renderers.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.renderers.\n\nValidates that the renderer public API surface — types, registry functions,\nfactory, renderer method signatures, and built-in renderer names — remains\nstable for downstream consumers.\n\"\"\"\n\nimport inspect\n\nimport pytest\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.renderers import (\n    ContentPart,\n    DeepSeekV3ThinkingRenderer,\n    GptOssRenderer,\n    ImagePart,\n    Message,\n    MessageDelta,\n    Qwen3Renderer,\n    RenderContext,\n    Renderer,\n    Role,\n    StreamingMessageHeader,\n    StreamingTextDelta,\n    StreamingThinkingDelta,\n    TextPart,\n    ThinkingPart,\n    ToolCall,\n    ToolSpec,\n    TrainOnWhat,\n    Utf8TokenDecoder,\n    ensure_text,\n    format_content_as_string,\n    get_registered_renderer_names,\n    get_text_content,\n    is_renderer_registered,\n    parse_content_blocks,\n    register_renderer,\n    unregister_renderer,\n)\nfrom tinker_cookbook.renderers.base import ensure_list\nfrom tinker_cookbook.renderers.kimi_k2 import KimiK2Renderer\nfrom tinker_cookbook.renderers.kimi_k25 import KimiK25Renderer\n\n# ---------------------------------------------------------------------------\n# Type exports\n# ---------------------------------------------------------------------------\n\n\nclass TestTypeExports:\n    \"\"\"Verify that all types used by downstream are importable.\"\"\"\n\n    def test_message_is_typed_dict(self):\n        # Message is used as a TypedDict / dict with role+content\n        msg: Message = {\"role\": \"user\", \"content\": \"hello\"}\n        assert msg[\"role\"] == \"user\"\n        assert msg[\"content\"] == \"hello\"\n\n    def test_text_part_constructable(self):\n        part = TextPart(type=\"text\", text=\"hello\")\n        assert part[\"text\"] == \"hello\"\n\n    def test_thinking_part_constructable(self):\n        part = ThinkingPart(type=\"thinking\", thinking=\"reasoning\")\n        assert part[\"thinking\"] == \"reasoning\"\n\n    def test_tool_call_constructable(self):\n        tc = ToolCall(\n            function=ToolCall.FunctionBody(\n                name=\"my_tool\",\n                arguments='{\"key\": \"value\"}',\n            ),\n            id=\"call_123\",\n        )\n        assert tc.function.name == \"my_tool\"\n        assert tc.id == \"call_123\"\n\n    def test_tool_spec_constructable(self):\n        spec = ToolSpec(\n            name=\"my_tool\",\n            description=\"A tool\",\n            parameters={\"type\": \"object\", \"properties\": {}},\n        )\n        assert spec[\"name\"] == \"my_tool\"\n\n    def test_train_on_what_has_expected_values(self):\n        # Downstream uses at least LAST_ASSISTANT_MESSAGE\n        assert hasattr(TrainOnWhat, \"LAST_ASSISTANT_MESSAGE\")\n\n    def test_streaming_types_importable(self):\n        # These are used by projects/tinker_chat\n        assert StreamingMessageHeader is not None\n        assert StreamingTextDelta is not None\n        assert StreamingThinkingDelta is not None\n        assert MessageDelta is not None\n\n    def test_content_part_types(self):\n        assert ContentPart is not None\n        assert ImagePart is not None\n        assert Role is not None\n\n    def test_utf8_token_decoder_importable(self):\n        assert Utf8TokenDecoder is not None\n\n    def test_render_context_importable(self):\n        assert RenderContext is not None\n\n\n# ---------------------------------------------------------------------------\n# 
Utility functions\n# ---------------------------------------------------------------------------\n\n\nclass TestUtilityFunctions:\n    def test_ensure_text_with_string(self):\n        assert ensure_text(\"hello\") == \"hello\"\n\n    def test_format_content_as_string(self):\n        result = format_content_as_string(\"hello\")\n        assert isinstance(result, str)\n\n    def test_get_text_content(self):\n        msg: Message = {\"role\": \"user\", \"content\": \"hello\"}\n        assert get_text_content(msg) == \"hello\"\n\n    def test_parse_content_blocks_exists(self):\n        assert callable(parse_content_blocks)\n\n    def test_ensure_list_importable(self):\n        # Used by downstream rust extensions\n        assert callable(ensure_list)\n\n\n# ---------------------------------------------------------------------------\n# Registry functions\n# ---------------------------------------------------------------------------\n\n\nclass TestRendererRegistry:\n    def test_register_and_unregister_roundtrip(self):\n        name = \"__test_downstream_compat_renderer__\"\n        assert not is_renderer_registered(name)\n\n        def factory(tokenizer, image_processor=None):  # type: ignore[no-untyped-def]\n            return Qwen3Renderer(tokenizer)\n\n        register_renderer(name, factory)\n        assert is_renderer_registered(name)\n        assert name in get_registered_renderer_names()\n\n        assert unregister_renderer(name) is True\n        assert not is_renderer_registered(name)\n\n    def test_unregister_nonexistent_returns_false(self):\n        assert unregister_renderer(\"__nonexistent__\") is False\n\n\n# ---------------------------------------------------------------------------\n# get_renderer: built-in renderer names\n# ---------------------------------------------------------------------------\n\n# These are the renderer names downstream projects depend on.\nEXPECTED_RENDERER_NAMES = [\n    \"role_colon\",\n    \"llama3\",\n    \"qwen3\",\n    \"qwen3_disable_thinking\",\n    \"qwen3_instruct\",\n    \"qwen3_5\",\n    \"qwen3_5_disable_thinking\",\n    \"deepseekv3\",\n    \"deepseekv3_disable_thinking\",\n    \"deepseekv3_thinking\",\n    \"kimi_k2\",\n    \"kimi_k25\",\n    \"kimi_k25_disable_thinking\",\n    \"gpt_oss_no_sysprompt\",\n    \"gpt_oss_low_reasoning\",\n    \"gpt_oss_medium_reasoning\",\n    \"gpt_oss_high_reasoning\",\n    \"nemotron3\",\n    \"nemotron3_disable_thinking\",\n]\n\n\n@pytest.mark.parametrize(\"renderer_name\", EXPECTED_RENDERER_NAMES)\ndef test_builtin_renderer_name_resolves(renderer_name):\n    \"\"\"get_renderer must not raise ValueError for any name downstream projects use.\"\"\"\n    # We don't actually instantiate (needs a real tokenizer), just verify the\n    # name is handled in the factory's dispatch logic.\n    src = inspect.getsource(renderers.get_renderer)\n    # VL renderers use a different code path but the name must still appear\n    assert renderer_name in src or renderer_name.replace(\"_\", \" \") in src, (\n        f\"Renderer name '{renderer_name}' not found in get_renderer dispatch\"\n    )\n\n\n# ---------------------------------------------------------------------------\n# Renderer abstract interface\n# ---------------------------------------------------------------------------\n\n\nclass TestRendererInterface:\n    \"\"\"Verify the Renderer ABC exposes the methods downstream calls.\"\"\"\n\n    def test_build_generation_prompt_is_method(self):\n        assert hasattr(Renderer, \"build_generation_prompt\")\n        
assert callable(Renderer.build_generation_prompt)\n\n    def test_build_supervised_example_is_method(self):\n        # Downstream calls build_supervised_example (singular)\n        assert hasattr(Renderer, \"build_supervised_example\")\n\n    def test_build_supervised_examples_is_method(self):\n        # Some downstream code uses the plural form\n        assert hasattr(Renderer, \"build_supervised_examples\")\n\n    def test_parse_response_is_method(self):\n        assert hasattr(Renderer, \"parse_response\")\n\n    def test_get_stop_sequences_is_abstract(self):\n        assert hasattr(Renderer, \"get_stop_sequences\")\n\n    def test_has_extension_property(self):\n        assert hasattr(Renderer, \"has_extension_property\")\n\n    def test_tokenizer_attribute(self):\n        # Downstream accesses renderer.tokenizer\n        assert \"tokenizer\" in Renderer.__init__.__code__.co_varnames\n\n    def test_pickle_metadata_attributes(self):\n        # Downstream relies on pickle support\n        assert hasattr(Renderer, \"_renderer_name\")\n        assert hasattr(Renderer, \"_model_name\")\n        assert hasattr(Renderer, \"_has_image_processor\")\n\n\n# ---------------------------------------------------------------------------\n# Signature checks\n# ---------------------------------------------------------------------------\n\n\nclass TestSignatures:\n    \"\"\"Verify that key function signatures haven't changed.\"\"\"\n\n    def test_get_renderer_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(\n            renderers.get_renderer, [\"name\", \"tokenizer\", \"image_processor\", \"model_name\"]\n        )\n\n    def test_register_renderer_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(register_renderer, [\"name\", \"factory\"])\n\n    def test_unregister_renderer_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(unregister_renderer, [\"name\"])\n\n    def test_build_generation_prompt_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(Renderer.build_generation_prompt, [\"messages\", \"role\", \"prefill\"])\n\n    def test_build_supervised_example_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(Renderer.build_supervised_example, [\"messages\", \"train_on_what\"])\n\n    def test_parse_response_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(Renderer.parse_response, [\"response\"])\n\n    def test_ensure_text_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(ensure_text, [\"content\"])\n\n    def test_format_content_as_string_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(format_content_as_string, [\"content\", \"separator\"])\n\n    def test_get_text_content_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(get_text_content, [\"message\"])\n\n\n# ---------------------------------------------------------------------------\n# Specific renderer classes used directly by downstream\n# ---------------------------------------------------------------------------\n\n\nclass TestRendererClasses:\n    def 
test_deepseekv3_thinking_renderer_importable(self):\n        assert issubclass(DeepSeekV3ThinkingRenderer, Renderer)\n\n    def test_qwen3_renderer_importable(self):\n        assert issubclass(Qwen3Renderer, Renderer)\n\n    def test_gpt_oss_renderer_importable(self):\n        assert issubclass(GptOssRenderer, Renderer)\n\n    def test_kimi_k2_renderer_importable(self):\n        assert issubclass(KimiK2Renderer, Renderer)\n\n    def test_kimi_k25_renderer_importable(self):\n        assert issubclass(KimiK25Renderer, Renderer)\n"
  },
  {
    "path": "tests/downstream_compat/test_rl_train.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.rl.train and rl.data_processing.\n\nValidates that RL training entry points and data processing functions remain stable.\n\"\"\"\n\nimport inspect\n\nfrom tinker_cookbook.rl.data_processing import (\n    assemble_training_data,\n    compute_advantages,\n    trajectory_to_data,\n)\nfrom tinker_cookbook.rl.train import Config, main\n\n# ---------------------------------------------------------------------------\n# rl.train\n# ---------------------------------------------------------------------------\n\n\nclass TestRLTrainConfig:\n    def test_config_exists(self):\n        assert Config is not None\n\n    def test_main_exists(self):\n        assert callable(main)\n\n    def test_main_is_async(self):\n        assert inspect.iscoroutinefunction(main)\n\n    def test_main_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params_subset\n\n        assert_params_subset(main, [\"cfg\"])\n\n\n# ---------------------------------------------------------------------------\n# rl.data_processing\n# ---------------------------------------------------------------------------\n\n\nclass TestRLDataProcessing:\n    def test_compute_advantages_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(compute_advantages, [\"trajectory_groups_P\"])\n\n    def test_trajectory_to_data_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(trajectory_to_data, [\"traj\", \"traj_advantage\"])\n\n    def test_assemble_training_data_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(assemble_training_data, [\"trajectory_groups_P\", \"advantages_P\"])\n\n\n# ---------------------------------------------------------------------------\n# rl.metrics (used by tibo training code)\n# ---------------------------------------------------------------------------\n\n\nclass TestRLMetrics:\n    def test_metrics_importable(self):\n        from tinker_cookbook.rl.metrics import (\n            compute_kl_sample_train,\n            discounted_future_sum_vectorized,\n        )\n\n        assert callable(compute_kl_sample_train)\n        assert callable(discounted_future_sum_vectorized)\n\n    def test_metric_util_importable(self):\n        from tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics\n\n        assert RLTestSetEvaluator is not None\n        assert callable(compute_trajectory_metrics)\n\n\n# ---------------------------------------------------------------------------\n# rl.rollouts (used by web_search_tasks)\n# ---------------------------------------------------------------------------\n\n\nclass TestRLRollouts:\n    def test_do_single_rollout_importable(self):\n        from tinker_cookbook.rl.rollouts import do_single_rollout\n\n        assert callable(do_single_rollout)\n"
  },
  {
    "path": "tests/downstream_compat/test_rl_types.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.rl.types.\n\nValidates that the RL type system — Env, StepResult, Transition, Trajectory,\nTrajectoryGroup, EnvGroupBuilder, RLDataset, RLDatasetBuilder — remains stable.\n\"\"\"\n\nimport inspect\nfrom dataclasses import fields\n\nfrom tinker_cookbook.completers import TokensWithLogprobs\nfrom tinker_cookbook.rl.types import (\n    Action,\n    Env,\n    EnvGroupBuilder,\n    Logs,\n    Metrics,\n    Observation,\n    RLDataset,\n    RLDatasetBuilder,\n    StepResult,\n    Trajectory,\n    TrajectoryGroup,\n    Transition,\n)\n\n# ---------------------------------------------------------------------------\n# Type aliases\n# ---------------------------------------------------------------------------\n\n\nclass TestTypeAliases:\n    def test_action_is_list_int(self):\n        val: Action = [1, 2, 3]\n        assert isinstance(val, list)\n\n    def test_metrics_is_dict(self):\n        val: Metrics = {\"acc\": 0.5}\n        assert isinstance(val, dict)\n\n    def test_logs_is_dict(self):\n        val: Logs = {\"msg\": \"ok\", \"step\": 1}\n        assert isinstance(val, dict)\n\n    def test_observation_alias_exists(self):\n        assert Observation is not None\n\n\n# ---------------------------------------------------------------------------\n# StepResult\n# ---------------------------------------------------------------------------\n\n\nclass TestStepResult:\n    def test_fields(self):\n        names = {f.name for f in fields(StepResult)}\n        expected = {\n            \"reward\",\n            \"episode_done\",\n            \"next_observation\",\n            \"next_stop_condition\",\n            \"metrics\",\n            \"logs\",\n        }\n        assert expected.issubset(names)\n\n    def test_metrics_defaults_to_empty(self):\n        sr = StepResult(\n            reward=1.0,\n            episode_done=False,\n            next_observation=None,  # type: ignore[arg-type]\n            next_stop_condition=[],\n        )\n        assert sr.metrics == {}\n        assert sr.logs == {}\n\n\n# ---------------------------------------------------------------------------\n# Transition\n# ---------------------------------------------------------------------------\n\n\nclass TestTransition:\n    def test_fields(self):\n        names = {f.name for f in fields(Transition)}\n        expected = {\"ob\", \"ac\", \"reward\", \"episode_done\", \"metrics\", \"logs\"}\n        assert expected.issubset(names)\n\n    def test_constructable(self):\n        t = Transition(\n            ob=None,  # type: ignore[arg-type]\n            ac=TokensWithLogprobs(tokens=[1, 2], maybe_logprobs=None),\n            reward=0.5,\n            episode_done=False,\n        )\n        assert t.reward == 0.5\n\n\n# ---------------------------------------------------------------------------\n# Env\n# ---------------------------------------------------------------------------\n\n\nclass TestEnv:\n    def test_is_abstract(self):\n        assert inspect.isabstract(Env)\n\n    def test_has_initial_observation(self):\n        assert hasattr(Env, \"initial_observation\")\n        assert inspect.iscoroutinefunction(Env.initial_observation)\n\n    def test_initial_observation_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(Env.initial_observation, [])\n\n    def test_has_step(self):\n        assert hasattr(Env, \"step\")\n        assert inspect.iscoroutinefunction(Env.step)\n\n    def 
test_step_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(Env.step, [\"action\"])\n\n\n# ---------------------------------------------------------------------------\n# Trajectory and TrajectoryGroup\n# ---------------------------------------------------------------------------\n\n\nclass TestTrajectory:\n    def test_fields(self):\n        names = {f.name for f in fields(Trajectory)}\n        assert \"transitions\" in names\n        assert \"final_ob\" in names\n\n    def test_frozen(self):\n        assert Trajectory.__dataclass_params__.frozen  # type: ignore[attr-defined]\n\n\nclass TestTrajectoryGroup:\n    def test_has_trajectories_field(self):\n        names = {f.name for f in fields(TrajectoryGroup)}\n        assert \"trajectories_G\" in names\n        assert \"metrics_G\" in names\n\n    def test_get_total_rewards_method(self):\n        assert hasattr(TrajectoryGroup, \"get_total_rewards\")\n        assert callable(TrajectoryGroup.get_total_rewards)\n\n\n# ---------------------------------------------------------------------------\n# EnvGroupBuilder\n# ---------------------------------------------------------------------------\n\n\nclass TestEnvGroupBuilder:\n    def test_is_abstract(self):\n        assert inspect.isabstract(EnvGroupBuilder)\n\n    def test_make_envs_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert hasattr(EnvGroupBuilder, \"make_envs\")\n        assert inspect.iscoroutinefunction(EnvGroupBuilder.make_envs)\n        assert_params(EnvGroupBuilder.make_envs, [])\n\n    def test_compute_group_rewards_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert hasattr(EnvGroupBuilder, \"compute_group_rewards\")\n        assert inspect.iscoroutinefunction(EnvGroupBuilder.compute_group_rewards)\n        assert_params(EnvGroupBuilder.compute_group_rewards, [\"trajectory_group\", \"env_group\"])\n\n    def test_logging_tags_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert hasattr(EnvGroupBuilder, \"logging_tags\")\n        assert_params(EnvGroupBuilder.logging_tags, [])\n\n\n# ---------------------------------------------------------------------------\n# RLDataset and RLDatasetBuilder\n# ---------------------------------------------------------------------------\n\n\nclass TestRLDataset:\n    def test_is_abstract(self):\n        assert inspect.isabstract(RLDataset)\n\n    def test_get_batch_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert hasattr(RLDataset, \"get_batch\")\n        assert_params(RLDataset.get_batch, [\"index\"])\n\n    def test_has_len(self):\n        assert hasattr(RLDataset, \"__len__\")\n\n\nclass TestRLDatasetBuilder:\n    def test_has_call(self):\n        assert callable(RLDatasetBuilder)\n"
  },
  {
    "path": "tests/downstream_compat/test_supervised.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.supervised.\n\nValidates that supervised training types and data utilities remain stable.\n\"\"\"\n\nfrom tinker_cookbook.supervised.data import (\n    FromConversationFileBuilder,\n    StreamingSupervisedDatasetFromHFDataset,\n    SupervisedDatasetFromHFDataset,\n    conversation_to_datum,\n)\nfrom tinker_cookbook.supervised.types import (\n    ChatDatasetBuilder,\n    ChatDatasetBuilderCommonConfig,\n    SupervisedDataset,\n    SupervisedDatasetBuilder,\n)\n\n\nclass TestSupervisedTypes:\n    def test_supervised_dataset_has_get_batch(self):\n        assert hasattr(SupervisedDataset, \"get_batch\")\n\n    def test_supervised_dataset_has_len(self):\n        assert hasattr(SupervisedDataset, \"__len__\")\n\n    def test_supervised_dataset_builder_is_callable(self):\n        assert callable(SupervisedDatasetBuilder)\n\n    def test_chat_dataset_builder_is_subclass(self):\n        assert issubclass(ChatDatasetBuilder, SupervisedDatasetBuilder)\n\n    def test_chat_dataset_builder_common_config_exists(self):\n        assert ChatDatasetBuilderCommonConfig is not None\n\n\nclass TestSupervisedData:\n    def test_conversation_to_datum_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(\n            conversation_to_datum, [\"conversation\", \"renderer\", \"max_length\", \"train_on_what\"]\n        )\n\n    def test_from_conversation_file_builder_exists(self):\n        assert FromConversationFileBuilder is not None\n\n    def test_supervised_dataset_from_hf_exists(self):\n        assert SupervisedDatasetFromHFDataset is not None\n\n    def test_streaming_supervised_dataset_from_hf_exists(self):\n        assert StreamingSupervisedDatasetFromHFDataset is not None\n"
  },
  {
    "path": "tests/downstream_compat/test_tokenizer_utils.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.tokenizer_utils.\n\nValidates that the tokenizer registry API remains stable.\n\"\"\"\n\nfrom tinker_cookbook.tokenizer_utils import (\n    Tokenizer,\n    get_registered_tokenizer_names,\n    get_tokenizer,\n    is_tokenizer_registered,\n    register_tokenizer,\n    unregister_tokenizer,\n)\n\n\nclass TestTokenizerRegistryAPI:\n    def test_get_tokenizer_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(get_tokenizer, [\"model_name\"])\n\n    def test_register_tokenizer_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(register_tokenizer, [\"name\", \"factory\"])\n\n    def test_unregister_tokenizer_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(unregister_tokenizer, [\"name\"])\n\n    def test_is_tokenizer_registered_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(is_tokenizer_registered, [\"name\"])\n\n    def test_get_registered_tokenizer_names_callable(self):\n        assert callable(get_registered_tokenizer_names)\n\n    def test_tokenizer_type_alias_exists(self):\n        assert Tokenizer is not None\n\n    def test_register_and_unregister_roundtrip(self):\n        name = \"__test_downstream_compat_tokenizer__\"\n        assert not is_tokenizer_registered(name)\n\n        register_tokenizer(name, lambda: None)  # type: ignore[arg-type]\n        assert is_tokenizer_registered(name)\n        assert name in get_registered_tokenizer_names()\n\n        assert unregister_tokenizer(name) is True\n        assert not is_tokenizer_registered(name)\n"
  },
  {
    "path": "tests/downstream_compat/test_utils.py",
    "content": "\"\"\"Downstream compatibility tests for tinker_cookbook.utils.\n\nValidates that logging, tracing, and misc utilities remain stable.\n\"\"\"\n\nimport inspect\n\nfrom tinker_cookbook.utils.misc_utils import (\n    all_same,\n    concat_lists,\n    dict_mean,\n    not_none,\n    safezip,\n    split_list,\n    timed,\n)\nfrom tinker_cookbook.utils.ml_log import (\n    JsonLogger,\n    Logger,\n    MultiplexLogger,\n    PrettyPrintLogger,\n    configure_logging_module,\n    dump_config,\n    setup_logging,\n)\n\n# ---------------------------------------------------------------------------\n# ml_log\n# ---------------------------------------------------------------------------\n\n\nclass TestLoggerHierarchy:\n    def test_logger_is_abstract(self):\n        assert inspect.isabstract(Logger)\n\n    def test_json_logger_is_subclass(self):\n        assert issubclass(JsonLogger, Logger)\n\n    def test_pretty_print_logger_is_subclass(self):\n        assert issubclass(PrettyPrintLogger, Logger)\n\n    def test_multiplex_logger_is_subclass(self):\n        assert issubclass(MultiplexLogger, Logger)\n\n    def test_setup_logging_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(\n            setup_logging,\n            [\"log_dir\", \"wandb_project\", \"wandb_name\", \"config\", \"do_configure_logging_module\"],\n        )\n\n    def test_configure_logging_module_signature(self):\n        from tests.downstream_compat.sig_helpers import assert_params\n\n        assert_params(configure_logging_module, [\"path\", \"level\"])\n\n    def test_dump_config_callable(self):\n        assert callable(dump_config)\n\n\n# ---------------------------------------------------------------------------\n# misc_utils\n# ---------------------------------------------------------------------------\n\n\nclass TestMiscUtils:\n    def test_dict_mean(self):\n        result = dict_mean([{\"a\": 1.0, \"b\": 2.0}, {\"a\": 3.0, \"b\": 4.0}])\n        assert result[\"a\"] == 2.0\n        assert result[\"b\"] == 3.0\n\n    def test_all_same_true(self):\n        assert all_same([1, 1, 1]) is True\n\n    def test_all_same_false(self):\n        assert all_same([1, 2, 1]) is False\n\n    def test_split_list(self):\n        result = split_list([1, 2, 3, 4], 2)\n        assert len(result) == 2\n\n    def test_concat_lists(self):\n        result = concat_lists([[1, 2], [3, 4]])\n        assert result == [1, 2, 3, 4]\n\n    def test_not_none(self):\n        assert not_none(42) == 42\n\n    def test_safezip(self):\n        result = list(safezip([1, 2], [3, 4]))\n        assert result == [(1, 3), (2, 4)]\n\n    def test_timed_is_context_manager(self):\n        assert callable(timed)\n\n\n# ---------------------------------------------------------------------------\n# trace (import-only check — used by tibo training code)\n# ---------------------------------------------------------------------------\n\n\nclass TestTraceImports:\n    def test_trace_importable(self):\n        from tinker_cookbook.utils.trace import scope, trace_init, update_scope_context\n\n        assert callable(scope)\n        assert callable(trace_init)\n        assert callable(update_scope_context)\n\n    def test_logtree_importable(self):\n        from tinker_cookbook.utils import logtree\n\n        assert logtree is not None\n"
  },
  {
    "path": "tests/helpers.py",
    "content": "\"\"\"Shared helpers for recipe smoke tests.\"\"\"\n\nimport os\nimport select\nimport subprocess\nimport time\n\nimport pytest\n\n# Timeout for each recipe (seconds). Override with SMOKE_TEST_TIMEOUT env var.\nDEFAULT_TIMEOUT = int(os.environ.get(\"SMOKE_TEST_TIMEOUT\", \"1800\"))\n\n# Default number of training steps for smoke tests.\nDEFAULT_MAX_STEPS = 2\n\n\ndef run_recipe(\n    module: str,\n    args: list[str] | None = None,\n    timeout: int = DEFAULT_TIMEOUT,\n    max_steps: int = DEFAULT_MAX_STEPS,\n):\n    \"\"\"Run a recipe module for a limited number of steps and verify clean exit.\n\n    Passes max_steps to the recipe so it exits naturally after N training steps.\n    Output is streamed to stdout in real time for debuggability in CI.\n\n    Args:\n        module: Python module path (e.g., \"tinker_cookbook.recipes.chat_sl.train\")\n        args: CLI arguments to pass to the module\n        timeout: Maximum seconds to wait for the recipe to complete\n        max_steps: Number of training steps to run (passed as CLI arg)\n    \"\"\"\n    cmd = [\"uv\", \"run\", \"python\", \"-m\", module] + (args or []) + [f\"max_steps={max_steps}\"]\n    print(f\"\\n>>> {' '.join(cmd)}\", flush=True)\n\n    proc = subprocess.Popen(\n        cmd,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.STDOUT,\n    )\n\n    output_lines: list[str] = []\n    start_time = time.monotonic()\n\n    try:\n        assert proc.stdout is not None\n        fd = proc.stdout.fileno()\n        while True:\n            elapsed = time.monotonic() - start_time\n            if elapsed >= timeout:\n                proc.terminate()\n                proc.wait(timeout=10)\n                last_lines = \"\\n\".join(output_lines[-30:])\n                pytest.fail(\n                    f\"Recipe {module} did not complete within {timeout}s \"\n                    f\"(exit code: {proc.returncode})\\n\\nLast 30 lines:\\n{last_lines}\"\n                )\n\n            # Check if process exited\n            if proc.poll() is not None:\n                # Drain remaining output\n                for line in proc.stdout:\n                    decoded = line.decode(\"utf-8\", errors=\"replace\").rstrip(\"\\n\")\n                    output_lines.append(decoded)\n                    print(decoded, flush=True)\n                break\n\n            # Wait up to 5s for output, then re-check timeout\n            ready, _, _ = select.select([fd], [], [], 5.0)\n            if not ready:\n                continue\n\n            line = proc.stdout.readline()\n            if line:\n                decoded = line.decode(\"utf-8\", errors=\"replace\").rstrip(\"\\n\")\n                output_lines.append(decoded)\n                print(decoded, flush=True)\n    except Exception:\n        # Ensure cleanup on unexpected errors\n        proc.kill()\n        proc.wait(timeout=10)\n        raise\n\n    elapsed = time.monotonic() - start_time\n\n    if proc.returncode == 0:\n        print(f\"\\n>>> PASSED: recipe completed cleanly in {elapsed:.0f}s\", flush=True)\n        return\n\n    # Non-zero exit code\n    last_lines = \"\\n\".join(output_lines[-30:])\n    pytest.fail(\n        f\"Recipe {module} failed with exit code {proc.returncode}\\n\\nLast 30 lines:\\n{last_lines}\"\n    )\n"
  },
  {
    "path": "tests/recipes/__init__.py",
    "content": ""
  },
  {
    "path": "tests/recipes/test_recipe_chat_sl.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\nMODULE = \"tinker_cookbook.recipes.chat_sl.train\"\nLOG_PATH = \"/tmp/tinker-smoke-test/chat_sl_resume\"\n\n\n@pytest.mark.integration\ndef test_chat_sl():\n    \"\"\"Train SFT from scratch for 2 steps, saving a checkpoint at step 1.\"\"\"\n    run_recipe(\n        MODULE,\n        [\n            \"behavior_if_log_dir_exists=delete\",\n            f\"log_path={LOG_PATH}\",\n            \"save_every=1\",\n        ],\n    )\n\n\n@pytest.mark.integration\ndef test_chat_sl_resume():\n    \"\"\"Resume SFT training from the checkpoint saved by test_chat_sl.\"\"\"\n    run_recipe(\n        MODULE,\n        [\n            \"behavior_if_log_dir_exists=resume\",\n            f\"log_path={LOG_PATH}\",\n            \"save_every=1\",\n        ],\n        max_steps=4,\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_dpo.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_dpo():\n    run_recipe(\n        \"tinker_cookbook.recipes.preference.dpo.train\",\n        [\"behavior_if_log_dir_exists=delete\"],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_guess_number.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_guess_number():\n    run_recipe(\n        \"tinker_cookbook.recipes.multiplayer_rl.guess_number.train\",\n        [\n            \"batch_size=8\",\n            \"group_size=2\",\n        ],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_math_rl.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\nMODULE = \"tinker_cookbook.recipes.math_rl.train\"\n\n\n@pytest.mark.integration\ndef test_math_rl_sync():\n    run_recipe(\n        MODULE,\n        [\n            \"model_name=Qwen/Qwen3.5-4B\",\n            \"groups_per_batch=8\",\n            \"group_size=4\",\n            \"max_tokens=5\",\n            \"behavior_if_log_dir_exists=delete\",\n        ],\n    )\n\n\n@pytest.mark.integration\ndef test_math_rl_async():\n    run_recipe(\n        MODULE,\n        [\n            \"model_name=Qwen/Qwen3.5-4B\",\n            \"groups_per_batch=8\",\n            \"group_size=4\",\n            \"max_tokens=5\",\n            \"max_steps_off_policy=2\",\n            \"behavior_if_log_dir_exists=delete\",\n        ],\n    )\n\n\n@pytest.mark.integration\ndef test_math_rl_stream_minibatch():\n    run_recipe(\n        MODULE,\n        [\n            \"model_name=Qwen/Qwen3.5-4B\",\n            \"groups_per_batch=8\",\n            \"group_size=4\",\n            \"max_tokens=5\",\n            \"stream_minibatch_config.groups_per_batch=8\",\n            \"stream_minibatch_config.num_minibatches=2\",\n            \"behavior_if_log_dir_exists=delete\",\n        ],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_off_policy_reasoning.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_off_policy_reasoning():\n    run_recipe(\n        \"tinker_cookbook.recipes.distillation.off_policy_reasoning\",\n        [\n            \"batch_size=16\",\n            \"max_prompts=128\",\n            \"buffer_size=128\",\n            \"behavior_if_log_dir_exists=delete\",\n        ],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_on_policy_distillation.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_on_policy_distillation():\n    run_recipe(\n        \"tinker_cookbook.recipes.distillation.on_policy_distillation\",\n        [\n            \"groups_per_batch=16\",\n            \"behavior_if_log_dir_exists=delete\",\n        ],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_on_policy_multi_teacher.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_on_policy_multi_teacher():\n    run_recipe(\n        \"tinker_cookbook.recipes.distillation.on_policy_multi_teacher\",\n        [\n            \"deepmath_groups_per_batch=16\",\n            \"tulu3_groups_per_batch=16\",\n            \"behavior_if_log_dir_exists=delete\",\n        ],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_rlhf_pipeline.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_rlhf_pipeline():\n    run_recipe(\n        \"tinker_cookbook.recipes.preference.rlhf.rlhf_pipeline\",\n        [\"short_name=smoke-test\"],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_shorter.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_shorter():\n    run_recipe(\n        \"tinker_cookbook.recipes.preference.shorter.train\",\n        [\"behavior_if_log_dir_exists=delete\"],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_text_arena.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_text_arena():\n    run_recipe(\n        \"tinker_cookbook.recipes.multiplayer_rl.text_arena.train\",\n        [\n            \"batch_size=16\",\n            \"num_train_datapoints=128\",\n        ],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_twenty_questions.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_twenty_questions():\n    run_recipe(\n        \"tinker_cookbook.recipes.multiplayer_rl.twenty_questions.train\",\n        [\n            \"batch_size=8\",\n            \"group_size=2\",\n            \"num_epochs=1\",\n        ],\n    )\n"
  },
  {
    "path": "tests/recipes/test_recipe_vlm_classifier.py",
    "content": "import pytest\n\nfrom tests.helpers import run_recipe\n\n\n@pytest.mark.integration\ndef test_vlm_classifier():\n    run_recipe(\n        \"tinker_cookbook.recipes.vlm_classifier.train\",\n        [\n            \"experiment_dir=/tmp/tinker-smoke-test/vlm_classifier\",\n            \"model_name=Qwen/Qwen3-VL-30B-A3B-Instruct\",\n            \"batch_size=16\",\n            \"num_epochs=1\",\n            \"n_eval=16\",\n            \"behavior_if_log_dir_exists=delete\",\n        ],\n    )\n"
  },
  {
    "path": "tests/test_inspect_eval.py",
    "content": "\"\"\"Smoke tests for inspect evaluation integration.\n\nTests the include_reasoning parameter across thinking and non-thinking models\nby calling api.generate() directly to verify the adapter returns the correct\ncontent types to inspect_ai.\n\nTest matrix:\n  - Thinking model (Qwen3) + include_reasoning=True  → [ContentReasoning, ContentText]\n  - Thinking model (Qwen3) + include_reasoning=False → plain string, no <think> tags\n  - Non-thinking model (Llama 3.1) + include_reasoning=True  → [ContentText] only\n  - Non-thinking model (Llama 3.1) + include_reasoning=False → plain string\n\"\"\"\n\nimport asyncio\n\nimport pytest\n\npytest.importorskip(\"inspect_ai\")\n\nimport tinker\nfrom inspect_ai.model import ChatMessage as InspectAIChatMessage\nfrom inspect_ai.model import ChatMessageUser as InspectAIChatMessageUser\nfrom inspect_ai.model import ContentReasoning as InspectAIContentReasoning\nfrom inspect_ai.model import ContentText as InspectAIContentText\nfrom inspect_ai.model import GenerateConfig as InspectAIGenerateConfig\nfrom inspect_ai.model import ModelOutput as InspectAIModelOutput\n\nfrom tinker_cookbook.eval.inspect_utils import InspectAPIFromTinkerSampling\n\nTHINKING_MODEL = \"Qwen/Qwen3-8B\"\nTHINKING_RENDERER = \"qwen3\"\n\nNON_THINKING_MODEL = \"meta-llama/Llama-3.1-8B-Instruct\"\nNON_THINKING_RENDERER = \"llama3\"\n\nPROMPT: list[InspectAIChatMessage] = [\n    InspectAIChatMessageUser(content=\"What is 1 + 1? Reply with just the number.\")\n]\nGENERATE_CONFIG = InspectAIGenerateConfig(temperature=0.6, max_tokens=1024)\n\n\ndef _create_api(\n    model_name: str, renderer_name: str, include_reasoning: bool\n) -> InspectAPIFromTinkerSampling:\n    service_client = tinker.ServiceClient()\n    sampling_client = service_client.create_sampling_client(base_model=model_name)\n    return InspectAPIFromTinkerSampling(\n        renderer_name=renderer_name,\n        model_name=model_name,\n        sampling_client=sampling_client,\n        include_reasoning=include_reasoning,\n    )\n\n\nasync def _generate(api: InspectAPIFromTinkerSampling) -> InspectAIModelOutput:\n    return await api.generate(input=PROMPT, tools=[], tool_choice=\"auto\", config=GENERATE_CONFIG)\n\n\ndef _log_response(result: InspectAIModelOutput) -> None:\n    \"\"\"Print response content for CI debuggability.\"\"\"\n    content = result.choices[0].message.content\n    print(f\"\\n  Content type: {type(content).__name__}\")\n    if isinstance(content, str):\n        print(f\"  Text: {content[:300]!r}\")\n    else:\n        for i, part in enumerate(content):\n            if isinstance(part, InspectAIContentReasoning):\n                print(f\"  Part {i} [ContentReasoning]: {part.reasoning[:200]!r}\")\n            elif isinstance(part, InspectAIContentText):\n                print(f\"  Part {i} [ContentText]: {part.text[:200]!r}\")\n            else:\n                print(f\"  Part {i} [{type(part).__name__}]: {repr(part)[:200]}\")\n    usage = result.usage\n    if usage:\n        print(f\"  Tokens: {usage.input_tokens} in, {usage.output_tokens} out\")\n\n\n@pytest.mark.integration\ndef test_thinking_model_include_reasoning():\n    \"\"\"Thinking model + include_reasoning=True: response has ContentReasoning + ContentText.\"\"\"\n    api = _create_api(THINKING_MODEL, THINKING_RENDERER, include_reasoning=True)\n    result = asyncio.run(_generate(api))\n    _log_response(result)\n\n    content = result.choices[0].message.content\n    assert isinstance(content, list), f\"Expected list content, got 
{type(content)}\"\n\n    reasoning_parts = [c for c in content if isinstance(c, InspectAIContentReasoning)]\n    text_parts = [c for c in content if isinstance(c, InspectAIContentText)]\n    assert len(reasoning_parts) > 0, \"Expected ContentReasoning from thinking model\"\n    assert len(text_parts) > 0, \"Expected ContentText from thinking model\"\n    assert len(reasoning_parts[0].reasoning) > 0, \"Reasoning content should not be empty\"\n\n\n@pytest.mark.integration\ndef test_thinking_model_exclude_reasoning():\n    \"\"\"Thinking model + include_reasoning=False: response is plain string without <think> tags.\"\"\"\n    api = _create_api(THINKING_MODEL, THINKING_RENDERER, include_reasoning=False)\n    result = asyncio.run(_generate(api))\n    _log_response(result)\n\n    content = result.choices[0].message.content\n    assert isinstance(content, str), f\"Expected string content, got {type(content)}\"\n    assert \"<think>\" not in content, \"Reasoning should be stripped from string content\"\n\n\n@pytest.mark.integration\ndef test_non_thinking_model_include_reasoning():\n    \"\"\"Non-thinking model + include_reasoning=True: response has ContentText only, no crash.\"\"\"\n    api = _create_api(NON_THINKING_MODEL, NON_THINKING_RENDERER, include_reasoning=True)\n    result = asyncio.run(_generate(api))\n    _log_response(result)\n\n    content = result.choices[0].message.content\n    assert isinstance(content, list), f\"Expected list content, got {type(content)}\"\n\n    reasoning_parts = [c for c in content if isinstance(c, InspectAIContentReasoning)]\n    text_parts = [c for c in content if isinstance(c, InspectAIContentText)]\n    assert len(reasoning_parts) == 0, \"Non-thinking model should not produce ContentReasoning\"\n    assert len(text_parts) > 0, \"Expected ContentText from non-thinking model\"\n\n\n@pytest.mark.integration\ndef test_non_thinking_model_exclude_reasoning():\n    \"\"\"Non-thinking model + include_reasoning=False: response is plain string (baseline).\"\"\"\n    api = _create_api(NON_THINKING_MODEL, NON_THINKING_RENDERER, include_reasoning=False)\n    result = asyncio.run(_generate(api))\n    _log_response(result)\n\n    content = result.choices[0].message.content\n    assert isinstance(content, str), f\"Expected string content, got {type(content)}\"\n"
  },
  {
    "path": "tests/test_modal_sandbox.py",
    "content": "\"\"\"Smoke tests for ModalSandbox.\n\nRequire Modal authentication and network access; skipped when Modal is not\nconfigured locally (no MODAL_TOKEN_ID env var and no ~/.modal.toml).\n\nThe primary goal is to catch latency regressions in write_file — a previous\nbug where a missing drain() after write_eof() caused hangs. For context,\nrun_command (no stdin) has always been fast; write_file should be comparable\nafter the drain fix, not 30-60x slower.\n\"\"\"\n\nimport asyncio\nimport os\nimport time\n\nimport modal\nimport pytest\nimport pytest_asyncio\n\nfrom tinker_cookbook.sandbox.modal_sandbox import ModalSandbox\n\n_has_modal_auth = bool(\n    os.environ.get(\"MODAL_TOKEN_ID\") or os.path.exists(os.path.expanduser(\"~/.modal.toml\"))\n)\n\nrequires_modal = pytest.mark.skipif(not _has_modal_auth, reason=\"Modal not configured locally\")\n\n# Modal's debian_slim() defaults to the local Python version, which may not\n# be supported. Pin to 3.12 for sandbox creation.\n_MODAL_IMAGE = modal.Image.debian_slim(python_version=\"3.12\")\n\n\n@pytest_asyncio.fixture(scope=\"module\")\nasync def sandbox():\n    \"\"\"Shared Modal sandbox for all tests in this module.\"\"\"\n    sb = await ModalSandbox.create(image=_MODAL_IMAGE, timeout=120)\n    yield sb\n    await sb.cleanup()\n\n\nasync def _timed(coro):\n    \"\"\"Await a coroutine and return (result, elapsed_seconds).\"\"\"\n    start = time.monotonic()\n    result = await coro\n    return result, time.monotonic() - start\n\n\n@requires_modal\n@pytest.mark.asyncio\n@pytest.mark.timeout(20)\nasync def test_write_file_latency(sandbox):\n    \"\"\"write_file should complete in seconds, not minutes.\n\n    Before the drain fix, write_file would hang for ~60s (the full exec timeout)\n    because proc.stdin.write_eof() wasn't flushed. 
This test catches that\n    regression by asserting a generous upper bound of 15s.\n    \"\"\"\n    content = \"#!/bin/bash\\necho hello world\\n\"\n\n    result, elapsed = await _timed(\n        sandbox.write_file(\"/tmp/test.sh\", content, executable=True, timeout=30)\n    )\n    assert result.exit_code == 0, f\"write_file failed: {result.stderr}\"\n    assert elapsed < 15, f\"write_file took {elapsed:.1f}s — likely stdin EOF hang (expected <15s)\"\n\n    # Verify content was written correctly\n    read_result = await sandbox.run_command(\"cat /tmp/test.sh\")\n    assert read_result.exit_code == 0\n    assert read_result.stdout == content\n\n    # Verify executable bit\n    stat_result = await sandbox.run_command(\"test -x /tmp/test.sh && echo yes\")\n    assert stat_result.stdout.strip() == \"yes\"\n\n    print(f\"\\nwrite_file latency: {elapsed:.2f}s\")\n\n\n@requires_modal\n@pytest.mark.asyncio\n@pytest.mark.timeout(20)\nasync def test_write_file_binary(sandbox):\n    \"\"\"write_file should handle binary content correctly.\"\"\"\n    content = bytes(range(256))\n\n    result, elapsed = await _timed(sandbox.write_file(\"/tmp/binary.bin\", content, timeout=30))\n    assert result.exit_code == 0, f\"write_file failed: {result.stderr}\"\n    assert elapsed < 15, f\"write_file took {elapsed:.1f}s — likely stdin EOF hang (expected <15s)\"\n\n    # Verify size\n    size_result = await sandbox.run_command(\"wc -c < /tmp/binary.bin\")\n    assert size_result.exit_code == 0\n    assert int(size_result.stdout.strip()) == 256\n\n    print(f\"\\nwrite_file (binary) latency: {elapsed:.2f}s\")\n\n\n# ---------------------------------------------------------------------------\n# cleanup() resilience tests\n# ---------------------------------------------------------------------------\n\n\n@requires_modal\n@pytest.mark.asyncio\n@pytest.mark.timeout(20)\nasync def test_cleanup_after_timeout():\n    \"\"\"cleanup() should not raise even if the sandbox has already timed out.\"\"\"\n    # The minimum timeout is 10 seconds.\n    sb = await ModalSandbox.create(image=_MODAL_IMAGE, timeout=10)\n\n    # Wait for the sandbox to time out\n    await asyncio.sleep(12)\n\n    # cleanup() should succeed without raising SandboxTimeoutError\n    await sb.cleanup()\n\n\n@requires_modal\n@pytest.mark.asyncio\n@pytest.mark.timeout(10)\nasync def test_cleanup_after_terminate():\n    \"\"\"cleanup() should not raise if called twice (sandbox already terminated).\"\"\"\n    sb = await ModalSandbox.create(image=_MODAL_IMAGE, timeout=60)\n\n    # First cleanup terminates normally\n    await sb.cleanup()\n\n    # Second cleanup should not raise even though sandbox is already dead\n    await sb.cleanup()\n\n\n@requires_modal\n@pytest.mark.asyncio\n@pytest.mark.timeout(20)\nasync def test_cleanup_after_command_timeout():\n    \"\"\"cleanup() should work after a command hits the sandbox timeout.\"\"\"\n    sb = await ModalSandbox.create(image=_MODAL_IMAGE, timeout=10)\n\n    # Run a command that will outlast the sandbox timeout\n    await sb.run_command(\"sleep 30\", timeout=30)\n\n    # cleanup() should not raise\n    await sb.cleanup()\n"
  },
  {
    "path": "tests/third_party/__init__.py",
    "content": ""
  },
  {
    "path": "tests/third_party/test_litellm.py",
    "content": "\"\"\"End-to-end smoke test for the LiteLLM Tinker provider.\n\nRequires TINKER_API_KEY to be set (skipped otherwise, see conftest.py).\n\"\"\"\n\nimport litellm\nimport pytest\nimport tinker\n\nimport tinker_cookbook.third_party.litellm.provider as provider_mod\nfrom tinker_cookbook.third_party.litellm import register_litellm_provider\n\n# Use a small model for fast smoke testing\nBASE_MODEL = \"Qwen/Qwen3-4B-Instruct-2507\"\n\n\n@pytest.fixture(scope=\"module\")\ndef tinker_provider():\n    # Reset singleton so the module gets a fresh provider\n    old = provider_mod._registered_provider\n    provider_mod._registered_provider = None\n\n    provider = register_litellm_provider()\n    yield provider\n\n    # Clean up\n    litellm.custom_provider_map[:] = [\n        entry for entry in litellm.custom_provider_map if entry[\"custom_handler\"] is not provider\n    ]\n    provider_mod._registered_provider = old\n\n\n# ---------------------------------------------------------------------------\n# Pretrained model tests\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.integration\n@pytest.mark.asyncio\nasync def test_acompletion_basic(tinker_provider) -> None:\n    \"\"\"Basic async completion returns a valid response with tokens.\"\"\"\n    response = await litellm.acompletion(\n        model=\"tinker/test\",\n        messages=[{\"role\": \"user\", \"content\": \"What is 2+2? Answer with just the number.\"}],\n        base_model=BASE_MODEL,\n        temperature=0.0,\n        max_tokens=32,\n    )\n\n    assert len(response.choices) == 1\n    choice = response.choices[0]\n    assert choice.message.content is not None\n    assert len(choice.message.content) > 0\n    assert choice.finish_reason in (\"stop\", \"length\")\n\n    # Verify raw tokens are accessible\n    fields = choice.message.provider_specific_fields\n    assert fields is not None\n    assert isinstance(fields[\"prompt_token_ids\"], list)\n    assert isinstance(fields[\"completion_token_ids\"], list)\n    assert len(fields[\"prompt_token_ids\"]) > 0\n    assert len(fields[\"completion_token_ids\"]) > 0\n\n    # Usage should be populated\n    assert response.usage.prompt_tokens > 0\n    assert response.usage.completion_tokens > 0\n\n\n@pytest.mark.integration\n@pytest.mark.asyncio\nasync def test_acompletion_with_system_message(tinker_provider) -> None:\n    \"\"\"System messages are handled correctly.\"\"\"\n    response = await litellm.acompletion(\n        model=\"tinker/test\",\n        messages=[\n            {\"role\": \"system\", \"content\": \"You are a helpful assistant. Be concise.\"},\n            {\"role\": \"user\", \"content\": \"Say hello.\"},\n        ],\n        base_model=BASE_MODEL,\n        temperature=0.0,\n        max_tokens=32,\n    )\n\n    assert response.choices[0].message.content is not None\n\n\n@pytest.mark.integration\n@pytest.mark.asyncio\nasync def test_acompletion_multi_turn(tinker_provider) -> None:\n    \"\"\"Multi-turn conversations work.\"\"\"\n    response = await litellm.acompletion(\n        model=\"tinker/test\",\n        messages=[\n            {\"role\": \"user\", \"content\": \"My name is Alice.\"},\n            {\"role\": \"assistant\", \"content\": \"Hello Alice! 
How can I help you?\"},\n            {\"role\": \"user\", \"content\": \"What is my name?\"},\n        ],\n        base_model=BASE_MODEL,\n        temperature=0.0,\n        max_tokens=32,\n    )\n\n    assert response.choices[0].message.content is not None\n\n\n@pytest.mark.integration\ndef test_completion_sync(tinker_provider) -> None:\n    \"\"\"Sync completion also works.\"\"\"\n    response = litellm.completion(\n        model=\"tinker/test\",\n        messages=[{\"role\": \"user\", \"content\": \"Say hi.\"}],\n        base_model=BASE_MODEL,\n        temperature=0.0,\n        max_tokens=16,\n    )\n\n    assert response.choices[0].message.content is not None\n    fields = response.choices[0].message.provider_specific_fields\n    assert fields is not None\n    assert len(fields[\"completion_token_ids\"]) > 0\n\n\n# ---------------------------------------------------------------------------\n# Fine-tuned checkpoint test\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.integration\n@pytest.mark.asyncio\nasync def test_set_client_with_finetuned_checkpoint(tinker_provider) -> None:\n    \"\"\"set_client() with a checkpoint-based SamplingClient works end-to-end.\n\n    Creates a LoRA training client, saves weights immediately (untrained),\n    and samples through LiteLLM via set_client().\n    \"\"\"\n    service = tinker.ServiceClient()\n    training_client = service.create_lora_training_client(base_model=BASE_MODEL, rank=8)\n\n    # Save weights and get a sampling client for the checkpoint\n    checkpoint_sampler = training_client.save_weights_and_get_sampling_client(name=\"litellm_test\")\n\n    # Inject via set_client — base_model is derived from the sampling client\n    tinker_provider.set_client(checkpoint_sampler)\n\n    response = await litellm.acompletion(\n        model=\"tinker/finetuned-test\",\n        messages=[{\"role\": \"user\", \"content\": \"Say hello.\"}],\n        base_model=BASE_MODEL,\n        temperature=0.0,\n        max_tokens=32,\n    )\n\n    assert response.choices[0].message.content is not None\n    fields = response.choices[0].message.provider_specific_fields\n    assert fields is not None\n    assert len(fields[\"prompt_token_ids\"]) > 0\n    assert len(fields[\"completion_token_ids\"]) > 0\n\n\n# ---------------------------------------------------------------------------\n# Idempotent registration test\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.integration\ndef test_register_idempotent(tinker_provider) -> None:\n    \"\"\"Calling register_litellm_provider() again returns the same instance.\"\"\"\n    provider2 = register_litellm_provider()\n    assert provider2 is tinker_provider\n"
  },
  {
    "path": "tests/validate_temperature_logprobs.py",
    "content": "\"\"\"\nValidate temperature scaling in sampling by comparing pairwise logprob differences.\n\nTwo complementary checks ensure correctness across temperatures and sequence positions:\n1. Temperature scaling: Verifies (log p_τ(i) - log p_τ(j)) ≈ (1/τ) * (log p_1(i) - log p_1(j))\n2. Sequence-level consistency: Validates multi-token sampling returns accurate logprobs at each step.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nfrom collections.abc import Sequence\n\nimport chz\nimport numpy as np\nimport tinker\n\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\ndef _default_temperatures() -> list[float]:\n    return [0.5, 0.7, 1.0, 1.2, 1.5, 1.8]\n\n\n@chz.chz\nclass Config:\n    base_model: str\n    prompt: str = (\n        \"Explain temperature scaling in language model sampling, include a brief \"\n        \"example, and discuss calibration vs diversity trade-offs.\"\n    )\n    temperatures: list[float] = chz.field(default_factory=_default_temperatures)\n    baseline_temperature: float = 1.0\n    num_trials: int = 20\n    check_sequence_consistency: bool = True\n    consistency_check_length: int = 20\n    consistency_check_temp: float = 0.5\n    seed: int | None = 42\n    base_url: str | None = None\n\n\nasync def _sample_next_token(\n    sampling_client: tinker.SamplingClient,\n    model_input: tinker.ModelInput,\n    *,\n    temperature: float,\n    max_tokens: int,\n    seed: int | None,\n) -> tuple[list[int], list[float]]:\n    resp = await sampling_client.sample_async(\n        prompt=model_input,\n        num_samples=1,\n        sampling_params=tinker.SamplingParams(\n            max_tokens=max_tokens,\n            temperature=temperature,\n            seed=seed,\n        ),\n    )\n    seq = resp.sequences[0]\n    if seq.logprobs is None:\n        raise RuntimeError(\"Sampling response did not include logprobs\")\n    return seq.tokens, seq.logprobs\n\n\nasync def _collect_sampled_token_logprobs(\n    sampling_client: tinker.SamplingClient,\n    model_input: tinker.ModelInput,\n    *,\n    temperature: float,\n    num_trials: int,\n    max_tokens: int,\n    seed: int | None,\n) -> dict[int, float]:\n    \"\"\"Collect token_id -> logprob at a given temperature over several trials.\"\"\"\n    out: dict[int, float] = {}\n    base = 0 if seed is None else seed\n    for i in range(num_trials):\n        s = base + i if seed is not None else None\n        tokens, lps = await _sample_next_token(\n            sampling_client,\n            model_input,\n            temperature=temperature,\n            max_tokens=max_tokens,\n            seed=s,\n        )\n        if not tokens:\n            continue\n        t = tokens[0]\n        out.setdefault(t, lps[0])\n    return out\n\n\nasync def _compute_logp1_for_tokens(\n    sampling_client: tinker.SamplingClient,\n    prompt_tokens: list[int],\n    tokens: Sequence[int],\n) -> dict[int, float]:\n    \"\"\"Compute baseline log p_1(token|prompt) for each token via compute_logprobs_async.\"\"\"\n    res: dict[int, float] = {}\n    for tok in tokens:\n        seq = tinker.ModelInput.from_ints(prompt_tokens + [tok])\n        lps = await sampling_client.compute_logprobs_async(seq)\n        lp = lps[len(prompt_tokens)]\n        if lp is None:\n            raise RuntimeError(\n                \"compute_logprobs_async did not return a logprob for the sampled token\"\n            )\n        res[tok] = lp\n    return res\n\n\ndef _pairwise_ratio_metrics(\n    base_logp: dict[int, float],\n    temp_logp: 
dict[int, float],\n    temperature: float,\n) -> dict[str, float]:\n    \"\"\"Compare pairwise logprob differences: (log p_τ(i) - log p_τ(j)) vs (1/τ) * (log p_1(i) - log p_1(j)).\"\"\"\n    common = sorted(set(base_logp) & set(temp_logp))\n    if len(common) < 2:\n        return {\n            \"tokens\": float(len(common)),\n            \"pairs\": 0.0,\n            \"mean_abs_err\": float(\"nan\"),\n            \"max_abs_err\": float(\"nan\"),\n        }\n    base_diffs: list[float] = []\n    temp_diffs: list[float] = []\n    inv_tau = 1.0 / max(temperature, 1e-9)\n    for a in range(len(common)):\n        for b in range(a + 1, len(common)):\n            i, j = common[a], common[b]\n            base_diffs.append(inv_tau * (base_logp[i] - base_logp[j]))\n            temp_diffs.append(temp_logp[i] - temp_logp[j])\n    x = np.array(base_diffs, dtype=float)\n    y = np.array(temp_diffs, dtype=float)\n    abs_err = np.abs(y - x)\n    mean_abs_err = float(np.mean(abs_err))\n    max_abs_err = float(np.max(abs_err))\n    return {\n        \"tokens\": float(len(common)),\n        \"pairs\": float(len(base_diffs)),\n        \"mean_abs_err\": mean_abs_err,\n        \"max_abs_err\": max_abs_err,\n    }\n\n\n# ============================================================================\n# Sequence-level consistency validation\n# ============================================================================\n\n\nasync def _sample_sequence_oneshot(\n    sampling_client: tinker.SamplingClient,\n    prompt_tokens: list[int],\n    *,\n    temperature: float,\n    max_tokens: int,\n    seed: int | None,\n) -> tuple[list[int], list[float]]:\n    \"\"\"Sample a sequence in one call with max_tokens > 1.\"\"\"\n    model_input = tinker.ModelInput.from_ints(prompt_tokens)\n    resp = await sampling_client.sample_async(\n        prompt=model_input,\n        num_samples=1,\n        sampling_params=tinker.SamplingParams(\n            max_tokens=max_tokens,\n            temperature=temperature,\n            seed=seed,\n        ),\n    )\n    seq = resp.sequences[0]\n    if seq.logprobs is None:\n        raise RuntimeError(\"Sampling response did not include logprobs\")\n    return seq.tokens, seq.logprobs\n\n\nasync def _resample_tokens_individually(\n    sampling_client: tinker.SamplingClient,\n    prompt_tokens: list[int],\n    *,\n    temperature: float,\n    length: int,\n    seed: int | None,\n) -> tuple[list[int], list[float]]:\n    \"\"\"Sample tokens one at a time, feeding each back into the prefix.\n\n    This mimics what max_tokens > 1 should do internally: sample token i,\n    append to context, then sample token i+1.\n    \"\"\"\n    tokens: list[int] = []\n    logprobs: list[float] = []\n    current_prefix = prompt_tokens.copy()\n\n    for i in range(length):\n        model_input = tinker.ModelInput.from_ints(current_prefix)\n        # Increment seed for each position to get different random states\n        pos_seed = (seed + i) if seed is not None else None\n\n        resp = await sampling_client.sample_async(\n            prompt=model_input,\n            num_samples=1,\n            sampling_params=tinker.SamplingParams(\n                max_tokens=1,\n                temperature=temperature,\n                seed=pos_seed,\n            ),\n        )\n        seq = resp.sequences[0]\n        if not seq.tokens or seq.logprobs is None:\n            break\n\n        tok = seq.tokens[0]\n        logprob = seq.logprobs[0]\n        tokens.append(tok)\n        logprobs.append(logprob)\n        
current_prefix.append(tok)\n\n    return tokens, logprobs\n\n\ndef _compare_logprobs(\n    sampled_logprobs: list[float],\n    computed_logprobs: list[float],\n) -> dict[str, float]:\n    \"\"\"Compare sampled vs recomputed logprobs.\"\"\"\n    min_len = min(len(sampled_logprobs), len(computed_logprobs))\n    if min_len == 0:\n        return {\n            \"length\": 0.0,\n            \"mean_diff\": float(\"nan\"),\n            \"max_diff\": float(\"nan\"),\n        }\n\n    diffs = [abs(sampled_logprobs[i] - computed_logprobs[i]) for i in range(min_len)]\n\n    return {\n        \"length\": float(min_len),\n        \"mean_diff\": float(np.mean(diffs)),\n        \"max_diff\": float(np.max(diffs)),\n    }\n\n\nasync def validate_sequence_consistency(\n    sampling_client: tinker.SamplingClient,\n    prompt_tokens: list[int],\n    *,\n    temperature: float,\n    length: int,\n    seed: int | None,\n    tokenizer,\n) -> None:\n    \"\"\"Validate that sample_async(max_tokens > 1) returns accurate per-step logprobs.\n\n    Generates a sequence then resamples each position individually to find matching tokens\n    and compare their logprobs, validating correctness at each step.\n    \"\"\"\n    print(\"\\n\" + \"=\" * 75)\n    print(\"SEQUENCE-LEVEL CONSISTENCY CHECK (multi-token logprob validation)\")\n    print(\"=\" * 75)\n    print(\n        f\"Generate with max_tokens={length} at temp={temperature}, then resample each position individually to verify logprob consistency.\"\n    )\n    print(f\"{'Temp':>8}  {'Length':>8}  {'Matches':>8}  {'Mean Diff':>12}  {'Max Diff':>12}\")\n    print(\"-\" * 75)\n\n    tau = temperature\n    gen_tokens, gen_logprobs = await _sample_sequence_oneshot(\n        sampling_client, prompt_tokens, temperature=tau, max_tokens=length, seed=seed\n    )\n\n    matching_diffs: list[float] = []\n    num_attempts_per_position = 5\n\n    for i in range(len(gen_tokens)):\n        prefix = prompt_tokens + gen_tokens[:i]\n        model_input = tinker.ModelInput.from_ints(prefix)\n\n        for attempt in range(num_attempts_per_position):\n            resp = await sampling_client.sample_async(\n                prompt=model_input,\n                num_samples=1,\n                sampling_params=tinker.SamplingParams(\n                    max_tokens=1,\n                    temperature=tau,\n                    seed=(seed + 1000 * (i + 1) + attempt) if seed is not None else None,\n                ),\n            )\n            seq = resp.sequences[0]\n            if not seq.tokens or seq.logprobs is None:\n                continue\n\n            if seq.tokens[0] == gen_tokens[i]:\n                matching_diffs.append(abs(gen_logprobs[i] - seq.logprobs[0]))\n                break\n\n    if len(matching_diffs) == 0:\n        print(f\"{tau:>8.3f}  {len(gen_tokens):>8}  {0:>8}  {'N/A':>12}  {'N/A':>12}  {'N/A':>8}\")\n        return\n\n    mean_diff = float(np.mean(matching_diffs))\n    max_diff = float(np.max(matching_diffs))\n    print(\n        f\"{tau:>8.3f}  {len(gen_tokens):>8}  {len(matching_diffs):>8}  {mean_diff:>12.6f}  {max_diff:>12.6f}\"\n    )\n    print()\n\n\nasync def main(cfg: Config) -> None:\n    tokenizer = get_tokenizer(cfg.base_model)\n    prompt_tokens = tokenizer.encode(cfg.prompt)\n    model_input = tinker.ModelInput.from_ints(prompt_tokens)\n\n    service = tinker.ServiceClient(base_url=cfg.base_url)\n    sampler = service.create_sampling_client(base_model=cfg.base_model)\n\n    print(\"\\n\" + \"=\" * 75)\n    print(\"TEMPERATURE SCALING VALIDATION\")\n 
   print(\"=\" * 75)\n\n    base_seen = await _collect_sampled_token_logprobs(\n        sampler,\n        model_input,\n        temperature=cfg.baseline_temperature,\n        num_trials=cfg.num_trials,\n        max_tokens=1,\n        seed=cfg.seed,\n    )\n    base_logp = await _compute_logp1_for_tokens(sampler, prompt_tokens, list(base_seen))\n\n    print(f\"Model: {cfg.base_model}, {cfg.num_trials} trials per temperature\")\n    print(f\"{'Temp':>8}  {'Unique Tokens':>15}  {'Pairs':>8}  {'Mean Diff':>12}  {'Max Diff':>12}\")\n    print(\"-\" * 75)\n\n    for tau in cfg.temperatures:\n        temp_seen = await _collect_sampled_token_logprobs(\n            sampler,\n            model_input,\n            temperature=tau,\n            num_trials=cfg.num_trials,\n            max_tokens=1,\n            seed=cfg.seed,\n        )\n        missing = [t for t in temp_seen if t not in base_logp]\n        if missing:\n            base_logp.update(await _compute_logp1_for_tokens(sampler, prompt_tokens, missing))\n        metrics = _pairwise_ratio_metrics(base_logp, temp_seen, tau)\n\n        mean_diff = metrics[\"mean_abs_err\"]\n        max_diff = metrics[\"max_abs_err\"]\n        print(\n            f\"{tau:>8.3f}  {int(metrics['tokens']):>15}  {int(metrics['pairs']):>8}  {mean_diff:>12.6f}  {max_diff:>12.6f}\"\n        )\n\n    if cfg.check_sequence_consistency:\n        await validate_sequence_consistency(\n            sampler,\n            prompt_tokens,\n            temperature=cfg.consistency_check_temp,\n            length=cfg.consistency_check_length,\n            seed=cfg.seed,\n            tokenizer=tokenizer,\n        )\n\n    print()\n\n\nif __name__ == \"__main__\":\n    asyncio.run(chz.nested_entrypoint(main))\n"
  },
  {
    "path": "tests/weights/__init__.py",
    "content": ""
  },
  {
    "path": "tests/weights/test_download.py",
    "content": "\"\"\"Integration test for weights.download().\n\nRequires TINKER_API_KEY to be set. Skipped otherwise.\n\"\"\"\n\nimport os\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom tinker_cookbook.weights import download\n\n\n@pytest.mark.integration\nclass TestDownloadIntegration:\n    \"\"\"Download a real adapter from Tinker and verify the extracted files.\"\"\"\n\n    def _get_test_tinker_path(self) -> str:\n        \"\"\"Return a known tinker checkpoint path for testing.\n\n        Uses the smoke test checkpoint if available via env var,\n        otherwise skips.\n        \"\"\"\n        path = os.environ.get(\"TINKER_TEST_CHECKPOINT_PATH\")\n        if not path:\n            pytest.skip(\n                \"Set TINKER_TEST_CHECKPOINT_PATH to a valid tinker:// path to run this test\"\n            )\n        return path\n\n    def test_download_and_extract(self):\n        tinker_path = self._get_test_tinker_path()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            output_dir = str(Path(tmpdir) / \"adapter\")\n\n            result = download(tinker_path=tinker_path, output_dir=output_dir)\n\n            assert result == output_dir\n            out = Path(output_dir)\n            assert out.is_dir(), f\"Output directory not created: {output_dir}\"\n\n            # Verify at least one file was extracted\n            files = list(out.rglob(\"*\"))\n            assert len(files) > 0, \"No files extracted from archive\"\n\n            # If it's a LoRA adapter, check for expected files\n            if (out / \"adapter_model.safetensors\").exists():\n                assert (out / \"adapter_config.json\").exists(), (\n                    \"adapter_model.safetensors found but adapter_config.json missing\"\n                )\n"
  },
  {
    "path": "tests/weights/test_export.py",
    "content": "\"\"\"End-to-end tests for build_hf_model across all supported model families.\n\nEach test instantiates a tiny real HuggingFace model from config (no weight\ndownload), saves it to disk with synthetic LoRA adapter weights, runs the\nfull build_hf_model pipeline, reloads, and verifies correctness.\n\nModel families tested:\n- GPT-OSS: fused interleaved gate_up_proj\n- Qwen3-VL MoE: fused concatenated gate_up_proj + vision model prefix\n- Qwen3 MoE: separate per-expert weights\n- DeepSeek V3.1: separate per-expert weights + FP8 quantized export\n- Qwen3 dense: standard linear layers (no experts)\n\"\"\"\n\nimport json\nimport shutil\nimport tempfile\nfrom pathlib import Path\n\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom safetensors.torch import load_file, save_file\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoModelForImageTextToText,\n    AutoTokenizer,\n    PretrainedConfig,\n)\n\nfrom tinker_cookbook.weights import build_hf_model\n\n# ---------------------------------------------------------------------------\n# Shared helpers\n# ---------------------------------------------------------------------------\n\nFILL_A = 0.01  # LoRA fill for gate / first projection\nFILL_B = 0.05  # LoRA fill for up / second projection\n\n\ndef _save_model_to_disk(\n    config: PretrainedConfig,\n    path: Path,\n    *,\n    tokenizer_name: str,\n    is_vision: bool = False,\n) -> None:\n    auto_cls = AutoModelForImageTextToText if is_vision else AutoModelForCausalLM\n    model = auto_cls.from_config(config, trust_remote_code=True, dtype=torch.float32)\n    model.save_pretrained(path)\n    tok = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)\n    tok.save_pretrained(path)\n\n\ndef _save_expert_adapter(\n    path: Path,\n    *,\n    num_experts: int,\n    in_dim: int,\n    out_dim: int,\n    gate_fill: float = FILL_A,\n    up_fill: float = FILL_B,\n    layer_prefix: str = \"base_model.model.model.layers.0.mlp.experts\",\n) -> None:\n    \"\"\"Save a LoRA adapter for expert gate (w1) and up (w3) projections.\"\"\"\n    weights: dict[str, torch.Tensor] = {}\n    rank = 1\n    weights[f\"{layer_prefix}.w1.lora_A.weight\"] = torch.ones(num_experts, rank, in_dim) * gate_fill\n    weights[f\"{layer_prefix}.w1.lora_B.weight\"] = torch.ones(num_experts, out_dim, rank)\n    weights[f\"{layer_prefix}.w3.lora_A.weight\"] = torch.ones(num_experts, rank, in_dim) * up_fill\n    weights[f\"{layer_prefix}.w3.lora_B.weight\"] = torch.ones(num_experts, out_dim, rank)\n\n    path.mkdir(parents=True)\n    save_file(weights, str(path / \"adapter_model.safetensors\"))\n    (path / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": rank}))\n\n\ndef _save_dense_adapter(\n    path: Path,\n    *,\n    in_dim: int,\n    out_dim: int,\n    fill: float = FILL_A,\n    layer_prefix: str = \"base_model.model.model.layers.0.mlp\",\n) -> None:\n    \"\"\"Save a LoRA adapter for a dense (non-expert) linear layer.\"\"\"\n    rank = 1\n    weights = {\n        f\"{layer_prefix}.gate_proj.lora_A.weight\": torch.ones(rank, in_dim) * fill,\n        f\"{layer_prefix}.gate_proj.lora_B.weight\": torch.ones(out_dim, rank),\n    }\n\n    path.mkdir(parents=True)\n    save_file(weights, str(path / \"adapter_model.safetensors\"))\n    (path / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": rank}))\n\n\ndef _run_build_and_reload(\n    model_path: Path,\n    adapter_path: Path,\n    output_path: Path,\n    *,\n   
 is_vision: bool = False,\n) -> dict[str, torch.Tensor]:\n    \"\"\"Run build_hf_model and return the reloaded state dict.\"\"\"\n    build_hf_model(\n        base_model=str(model_path),\n        adapter_path=str(adapter_path),\n        output_path=str(output_path),\n    )\n    auto_cls = AutoModelForImageTextToText if is_vision else AutoModelForCausalLM\n    reloaded = auto_cls.from_pretrained(output_path, trust_remote_code=True, dtype=torch.float32)\n    return reloaded.state_dict()\n\n\n# ---------------------------------------------------------------------------\n# 1. GPT-OSS — fused interleaved gate_up_proj\n# ---------------------------------------------------------------------------\n\n\ndef _make_tiny_gpt_oss_config() -> PretrainedConfig:\n    config = AutoConfig.from_pretrained(\"openai/gpt-oss-20b\", trust_remote_code=True)\n    config.num_hidden_layers = 1\n    config.num_local_experts = 2\n    config.hidden_size = 64\n    config.intermediate_size = 64\n    config.num_attention_heads = 2\n    config.num_key_value_heads = 2\n    config.layer_types = [\"full_attention\"]\n    if hasattr(config, \"quantization_config\"):\n        delattr(config, \"quantization_config\")\n    return config\n\n\nclass TestGptOssFusedInterleaved:\n    \"\"\"GPT-OSS: gate_up_proj with interleaved layout [g0, u0, g1, u1, ...].\"\"\"\n\n    FUSED_KEY = \"model.layers.0.mlp.experts.gate_up_proj\"\n\n    def test_gate_and_up_deltas_in_correct_interleaved_slots(self):\n        config = _make_tiny_gpt_oss_config()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            model_path, adapter_path, output_path = (\n                root / \"model\",\n                root / \"adapter\",\n                root / \"merged\",\n            )\n\n            _save_model_to_disk(config, model_path, tokenizer_name=\"openai/gpt-oss-20b\")\n            orig = AutoModelForCausalLM.from_pretrained(\n                model_path, trust_remote_code=True, dtype=torch.float32\n            )\n            orig_fused = orig.state_dict()[self.FUSED_KEY].clone()\n            num_experts, in_dim, fused_dim = orig_fused.shape\n\n            _save_expert_adapter(\n                adapter_path, num_experts=num_experts, in_dim=in_dim, out_dim=fused_dim // 2\n            )\n            merged_sd = _run_build_and_reload(model_path, adapter_path, output_path)\n\n            delta = merged_sd[self.FUSED_KEY] - orig_fused\n            gate_delta = delta[:, :, 0::2]\n            up_delta = delta[:, :, 1::2]\n\n            assert torch.allclose(gate_delta, torch.full_like(gate_delta, FILL_A), atol=1e-3)\n            assert torch.allclose(up_delta, torch.full_like(up_delta, FILL_B), atol=1e-3)\n\n    def test_up_only_does_not_modify_gate_slots(self):\n        config = _make_tiny_gpt_oss_config()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            model_path, adapter_path, output_path = (\n                root / \"model\",\n                root / \"adapter\",\n                root / \"merged\",\n            )\n\n            _save_model_to_disk(config, model_path, tokenizer_name=\"openai/gpt-oss-20b\")\n            orig = AutoModelForCausalLM.from_pretrained(\n                model_path, trust_remote_code=True, dtype=torch.float32\n            )\n            orig_gate = orig.state_dict()[self.FUSED_KEY][:, :, 0::2].clone()\n            num_experts, in_dim, fused_dim = orig.state_dict()[self.FUSED_KEY].shape\n\n            # Save only w3 (up) adapter\n      
      prefix = \"base_model.model.model.layers.0.mlp.experts\"\n            rank = 1\n            up_only = {\n                f\"{prefix}.w3.lora_A.weight\": torch.ones(num_experts, rank, in_dim) * FILL_B,\n                f\"{prefix}.w3.lora_B.weight\": torch.ones(num_experts, fused_dim // 2, rank),\n            }\n            adapter_path.mkdir(parents=True)\n            save_file(up_only, str(adapter_path / \"adapter_model.safetensors\"))\n            (adapter_path / \"adapter_config.json\").write_text(\n                json.dumps({\"lora_alpha\": 1, \"r\": rank})\n            )\n\n            merged_sd = _run_build_and_reload(model_path, adapter_path, output_path)\n            merged_gate = merged_sd[self.FUSED_KEY][:, :, 0::2]\n\n            assert torch.allclose(merged_gate, orig_gate, atol=1e-3), (\n                \"up adapter modified gate slots\"\n            )\n\n\n# ---------------------------------------------------------------------------\n# 2. Qwen3-VL MoE — fused concatenated gate_up_proj + vision prefix\n# ---------------------------------------------------------------------------\n\n\ndef _make_tiny_qwen3_vl_moe_config() -> PretrainedConfig:\n    config = AutoConfig.from_pretrained(\"Qwen/Qwen3-VL-30B-A3B-Instruct\", trust_remote_code=True)\n    tc = config.text_config\n    tc.num_hidden_layers = 1\n    tc.num_experts = 2\n    tc.num_experts_per_tok = 1\n    tc.hidden_size = 64\n    tc.intermediate_size = 64\n    tc.num_attention_heads = 2\n    tc.num_key_value_heads = 2\n    config.vision_config.num_hidden_layers = 1\n    config.vision_config.hidden_size = 64\n    config.vision_config.intermediate_size = 64\n    config.vision_config.num_attention_heads = 2\n    return config\n\n\nclass TestQwen3VlMoeFusedConcatenated:\n    \"\"\"Qwen3-VL MoE: gate_up_proj with concatenated layout [gate | up].\n\n    Also tests the vision model language_model prefix remapping.\n    \"\"\"\n\n    FUSED_KEY = \"model.language_model.layers.0.mlp.experts.gate_up_proj\"\n\n    def test_gate_and_up_deltas_in_correct_halves(self):\n        config = _make_tiny_qwen3_vl_moe_config()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            model_path, adapter_path, output_path = (\n                root / \"model\",\n                root / \"adapter\",\n                root / \"merged\",\n            )\n\n            _save_model_to_disk(\n                config,\n                model_path,\n                tokenizer_name=\"Qwen/Qwen3-VL-30B-A3B-Instruct\",\n                is_vision=True,\n            )\n            orig = AutoModelForImageTextToText.from_pretrained(\n                model_path, trust_remote_code=True, dtype=torch.float32\n            )\n            orig_fused = orig.state_dict()[self.FUSED_KEY].clone()\n            num_experts, in_dim, fused_dim = orig_fused.shape\n            sz = fused_dim // 2\n\n            # Vision model: adapter uses model.layers... 
but HF has model.language_model.layers...\n            _save_expert_adapter(\n                adapter_path,\n                num_experts=num_experts,\n                in_dim=in_dim,\n                out_dim=sz,\n            )\n\n            merged_sd = _run_build_and_reload(model_path, adapter_path, output_path, is_vision=True)\n\n            delta = merged_sd[self.FUSED_KEY] - orig_fused\n            gate_half = delta[:, :, :sz]\n            up_half = delta[:, :, sz:]\n\n            assert torch.allclose(gate_half, torch.full_like(gate_half, FILL_A), atol=1e-3)\n            assert torch.allclose(up_half, torch.full_like(up_half, FILL_B), atol=1e-3)\n\n    def test_up_only_does_not_modify_gate_half(self):\n        config = _make_tiny_qwen3_vl_moe_config()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            model_path, adapter_path, output_path = (\n                root / \"model\",\n                root / \"adapter\",\n                root / \"merged\",\n            )\n\n            _save_model_to_disk(\n                config,\n                model_path,\n                tokenizer_name=\"Qwen/Qwen3-VL-30B-A3B-Instruct\",\n                is_vision=True,\n            )\n            orig = AutoModelForImageTextToText.from_pretrained(\n                model_path, trust_remote_code=True, dtype=torch.float32\n            )\n            orig_fused = orig.state_dict()[self.FUSED_KEY].clone()\n            num_experts, in_dim, fused_dim = orig_fused.shape\n            sz = fused_dim // 2\n\n            # Only w3 (up) adapter\n            prefix = \"base_model.model.model.layers.0.mlp.experts\"\n            rank = 1\n            weights = {\n                f\"{prefix}.w3.lora_A.weight\": torch.ones(num_experts, rank, in_dim) * FILL_B,\n                f\"{prefix}.w3.lora_B.weight\": torch.ones(num_experts, sz, rank),\n            }\n            adapter_path.mkdir(parents=True)\n            save_file(weights, str(adapter_path / \"adapter_model.safetensors\"))\n            (adapter_path / \"adapter_config.json\").write_text(\n                json.dumps({\"lora_alpha\": 1, \"r\": rank})\n            )\n\n            merged_sd = _run_build_and_reload(model_path, adapter_path, output_path, is_vision=True)\n\n            orig_gate = orig_fused[:, :, :sz]\n            merged_gate = merged_sd[self.FUSED_KEY][:, :, :sz]\n\n            assert torch.allclose(merged_gate, orig_gate, atol=1e-3), (\n                \"up adapter modified gate half\"\n            )\n\n\n# ---------------------------------------------------------------------------\n# 3. 
Qwen3 MoE — separate per-expert weights\n# ---------------------------------------------------------------------------\n\n\ndef _make_tiny_qwen3_moe_config() -> PretrainedConfig:\n    config = AutoConfig.from_pretrained(\"Qwen/Qwen3-30B-A3B\", trust_remote_code=True)\n    config.num_hidden_layers = 1\n    config.num_experts = 2\n    config.num_experts_per_tok = 1\n    config.hidden_size = 64\n    config.intermediate_size = 64\n    config.num_attention_heads = 2\n    config.num_key_value_heads = 2\n    return config\n\n\nclass TestQwen3MoeSeparateExperts:\n    \"\"\"Qwen3 MoE: individual gate_proj/up_proj per expert.\"\"\"\n\n    def test_per_expert_weights_updated(self):\n        config = _make_tiny_qwen3_moe_config()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            model_path, adapter_path, output_path = (\n                root / \"model\",\n                root / \"adapter\",\n                root / \"merged\",\n            )\n\n            _save_model_to_disk(config, model_path, tokenizer_name=\"Qwen/Qwen3-30B-A3B\")\n            orig = AutoModelForCausalLM.from_pretrained(\n                model_path, trust_remote_code=True, dtype=torch.float32\n            )\n            orig_sd = {k: v.clone() for k, v in orig.state_dict().items()}\n            num_experts = 2\n\n            # Read actual dims from model (gate_proj shape is [intermediate, hidden])\n            gate_shape = orig_sd[\"model.layers.0.mlp.experts.0.gate_proj.weight\"].shape\n            expert_out_dim, expert_in_dim = gate_shape\n            _save_expert_adapter(\n                adapter_path, num_experts=num_experts, in_dim=expert_in_dim, out_dim=expert_out_dim\n            )\n            merged_sd = _run_build_and_reload(model_path, adapter_path, output_path)\n\n            for i in range(num_experts):\n                gate_key = f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\"\n                up_key = f\"model.layers.0.mlp.experts.{i}.up_proj.weight\"\n\n                gate_delta = (merged_sd[gate_key] - orig_sd[gate_key]).abs().sum()\n                up_delta = (merged_sd[up_key] - orig_sd[up_key]).abs().sum()\n\n                assert gate_delta > 0, f\"Expert {i} gate_proj not updated\"\n                assert up_delta > 0, f\"Expert {i} up_proj not updated\"\n\n\n# ---------------------------------------------------------------------------\n# 4. 
DeepSeek V3.1 — separate per-expert weights + FP8 quantized export\n# ---------------------------------------------------------------------------\n\n\ndef _make_tiny_deepseek_v31_config() -> PretrainedConfig:\n    config = AutoConfig.from_pretrained(\"deepseek-ai/DeepSeek-V3.1\", trust_remote_code=True)\n    config.num_hidden_layers = 1\n    config.hidden_size = 64\n    config.intermediate_size = 64\n    config.moe_intermediate_size = 16\n    config.num_attention_heads = 2\n    config.num_key_value_heads = 2\n    config.n_routed_experts = 2\n    config.n_shared_experts = 1\n    config.num_experts_per_tok = 1\n    config.first_k_dense_replace = 0\n    config.vocab_size = 256\n    if hasattr(config, \"quantization_config\"):\n        delattr(config, \"quantization_config\")\n    return config\n\n\ndef _copy_hf_files(repo_id: str, output_path: Path, file_names: tuple[str, ...]) -> None:\n    \"\"\"Download specific files from a HF repo and copy to output_path.\"\"\"\n    snapshot_path = Path(snapshot_download(repo_id=repo_id, allow_patterns=list(file_names)))\n    for file_name in file_names:\n        shutil.copy2(snapshot_path / file_name, output_path / file_name)\n\n\ndef _save_mixed_deepseek_adapter(\n    path: Path,\n    *,\n    num_experts: int,\n    expert_in_dim: int,\n    expert_out_dim: int,\n    dense_in_dim: int,\n    dense_out_dim: int,\n    dense_fill: float = FILL_A,\n    gate_fill: float = FILL_A,\n    up_fill: float = FILL_B,\n) -> None:\n    \"\"\"Save a DeepSeek adapter with both dense and routed-expert LoRA weights.\"\"\"\n    rank = 1\n    weights: dict[str, torch.Tensor] = {\n        \"base_model.model.model.layers.0.self_attn.q_a_proj.lora_A.weight\": (\n            torch.ones(rank, dense_in_dim, dtype=torch.bfloat16) * dense_fill\n        ),\n        \"base_model.model.model.layers.0.self_attn.q_a_proj.lora_B.weight\": torch.ones(\n            dense_out_dim, rank, dtype=torch.bfloat16\n        ),\n        \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": (\n            torch.ones(num_experts, rank, expert_in_dim, dtype=torch.bfloat16) * gate_fill\n        ),\n        \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": torch.ones(\n            num_experts, expert_out_dim, rank, dtype=torch.bfloat16\n        ),\n        \"base_model.model.model.layers.0.mlp.experts.w3.lora_A.weight\": (\n            torch.ones(num_experts, rank, expert_in_dim, dtype=torch.bfloat16) * up_fill\n        ),\n        \"base_model.model.model.layers.0.mlp.experts.w3.lora_B.weight\": torch.ones(\n            num_experts, expert_out_dim, rank, dtype=torch.bfloat16\n        ),\n    }\n\n    path.mkdir(parents=True)\n    save_file(weights, str(path / \"adapter_model.safetensors\"))\n    (path / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": rank}))\n\n\ndef _reshard_saved_model(\n    model_path: Path,\n    *,\n    shard_assignments: dict[str, str],\n    default_shard: str = \"model-00002-of-00002.safetensors\",\n) -> dict[str, str]:\n    \"\"\"Rewrite a local checkpoint into a small sharded layout with an HF index.\"\"\"\n    source_path = model_path / \"model.safetensors\"\n    state_dict = load_file(str(source_path))\n    shard_state_dicts: dict[str, dict[str, torch.Tensor]] = {}\n    weight_map: dict[str, str] = {}\n\n    for key, tensor in state_dict.items():\n        shard_name = shard_assignments.get(key, default_shard)\n        shard_state_dicts.setdefault(shard_name, {})[key] = tensor\n        weight_map[key] = shard_name\n\n    
source_path.unlink()\n    for shard_name, shard_sd in sorted(shard_state_dicts.items()):\n        save_file(shard_sd, str(model_path / shard_name))\n\n    total_size = sum(t.nelement() * t.element_size() for t in state_dict.values())\n    index = {\"metadata\": {\"total_size\": total_size}, \"weight_map\": weight_map}\n    (model_path / \"model.safetensors.index.json\").write_text(json.dumps(index, indent=2))\n    return weight_map\n\n\ndef _load_saved_state_dict(output_path: Path) -> dict[str, torch.Tensor]:\n    \"\"\"Load tensors exactly as written to disk, preserving saved dtypes.\"\"\"\n    state_dict: dict[str, torch.Tensor] = {}\n    for safetensors_path in sorted(output_path.glob(\"*.safetensors\")):\n        state_dict.update(load_file(str(safetensors_path)))\n    return state_dict\n\n\nclass TestDeepSeekV31FP8Export:\n    \"\"\"DeepSeek V3.1: dense weights stay BF16 while routed experts are quantized to FP8.\n\n    Uses real DeepSeek config from HF (downloads config + custom code, not weights).\n    \"\"\"\n\n    def test_dense_weights_change_but_only_routed_experts_are_quantized_to_fp8(self):\n        config = _make_tiny_deepseek_v31_config()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            model_path = root / \"model\"\n            adapter_path = root / \"adapter\"\n            output_path = root / \"merged\"\n\n            # Create model in BF16 to match real DeepSeek checkpoint format\n            _save_model_to_disk(config, model_path, tokenizer_name=\"deepseek-ai/DeepSeek-V3.1\")\n            _copy_hf_files(\n                \"deepseek-ai/DeepSeek-V3.1\",\n                model_path,\n                (\"configuration_deepseek.py\", \"modeling_deepseek.py\"),\n            )\n            # Re-save weights in BF16 (from_config creates float32 by default)\n            orig = AutoModelForCausalLM.from_pretrained(model_path, dtype=torch.bfloat16)\n            save_file(\n                {k: v.to(torch.bfloat16) for k, v in orig.state_dict().items()},\n                str(model_path / \"model.safetensors\"),\n            )\n            orig = AutoModelForCausalLM.from_pretrained(model_path, dtype=torch.bfloat16)\n            num_experts = 2\n\n            gate_shape = orig.state_dict()[\"model.layers.0.mlp.experts.0.gate_proj.weight\"].shape\n            expert_out_dim, expert_in_dim = gate_shape\n            dense_shape = orig.state_dict()[\"model.layers.0.self_attn.q_a_proj.weight\"].shape\n            dense_out_dim, dense_in_dim = dense_shape\n            dense_key = \"model.layers.0.self_attn.q_a_proj.weight\"\n            shared_expert_key = \"model.layers.0.mlp.shared_experts.gate_proj.weight\"\n            gate_keys = [\n                f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\" for i in range(num_experts)\n            ]\n            up_keys = [f\"model.layers.0.mlp.experts.{i}.up_proj.weight\" for i in range(num_experts)]\n\n            reference_weight_map = _reshard_saved_model(\n                model_path,\n                shard_assignments={\n                    dense_key: \"model-00001-of-00002.safetensors\",\n                    shared_expert_key: \"model-00002-of-00002.safetensors\",\n                    gate_keys[0]: \"model-00001-of-00002.safetensors\",\n                    up_keys[0]: \"model-00002-of-00002.safetensors\",\n                    gate_keys[1]: \"model-00002-of-00002.safetensors\",\n                    up_keys[1]: \"model-00001-of-00002.safetensors\",\n                },\n            
)\n\n            _save_mixed_deepseek_adapter(\n                adapter_path,\n                num_experts=num_experts,\n                expert_in_dim=expert_in_dim,\n                expert_out_dim=expert_out_dim,\n                dense_in_dim=dense_in_dim,\n                dense_out_dim=dense_out_dim,\n            )\n\n            build_hf_model(\n                base_model=str(model_path),\n                adapter_path=str(adapter_path),\n                output_path=str(output_path),\n                quantize=\"experts-fp8\",\n                serving_format=\"vllm\",\n            )\n\n            saved_sd = _load_saved_state_dict(output_path)\n            saved_index = json.loads((output_path / \"model.safetensors.index.json\").read_text())\n            saved_config = json.loads((output_path / \"config.json\").read_text())\n\n            # -- Custom files copied --\n            assert (output_path / \"configuration_deepseek.py\").exists()\n            assert (output_path / \"modeling_deepseek.py\").exists()\n            assert (output_path / \"model.safetensors.index.json\").exists()\n\n            # -- Dense weight: merged, BF16, shard preserved --\n            dense_delta = (\n                (saved_sd[dense_key].float() - orig.state_dict()[dense_key].float()).abs().sum()\n            )\n            assert dense_delta > 0, \"Dense q_a_proj weight was not updated\"\n            assert saved_sd[dense_key].dtype == torch.bfloat16\n            assert saved_index[\"weight_map\"][dense_key] == reference_weight_map[dense_key], (\n                \"Dense tensor should preserve reference shard placement\"\n            )\n\n            # -- Routed experts: merged, FP8, scale present, shard preserved --\n            for i in range(num_experts):\n                gate_key = f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\"\n                up_key = f\"model.layers.0.mlp.experts.{i}.up_proj.weight\"\n                gate_scale_key = gate_key.removesuffix(\".weight\") + \".weight_scale\"\n                up_scale_key = up_key.removesuffix(\".weight\") + \".weight_scale\"\n\n                assert saved_sd[gate_key].dtype == torch.float8_e4m3fn, (\n                    f\"Routed expert should be FP8: {gate_key}\"\n                )\n                assert saved_sd[gate_scale_key].dtype == torch.float32, (\n                    f\"Scale should be float32: {gate_scale_key}\"\n                )\n                assert saved_sd[up_scale_key].dtype == torch.float32\n                assert saved_index[\"weight_map\"][gate_key] == reference_weight_map[gate_key], (\n                    \"Routed expert should preserve reference shard placement\"\n                )\n                assert (\n                    saved_index[\"weight_map\"][gate_scale_key] == reference_weight_map[gate_key]\n                ), \"Scale should be in same shard as weight\"\n\n            # -- Shared experts: BF16, not quantized, shard preserved --\n            assert saved_sd[shared_expert_key].dtype == torch.bfloat16\n            assert (\n                saved_index[\"weight_map\"][shared_expert_key]\n                == reference_weight_map[shared_expert_key]\n            )\n\n            # -- No .weight_scale_inv in output (compressed-tensors convention) --\n            assert not any(key.endswith(\".weight_scale_inv\") for key in saved_sd), (\n                \"Should emit .weight_scale, not .weight_scale_inv\"\n            )\n\n            # -- Index consistency --\n            assert set(saved_index[\"weight_map\"]) == 
set(saved_sd)\n            shard_membership: dict[str, set[str]] = {}\n            for shard_path in sorted(output_path.glob(\"*.safetensors\")):\n                shard_membership[shard_path.name] = set(load_file(str(shard_path)).keys())\n            assert set(saved_index[\"weight_map\"].values()) == set(shard_membership)\n\n            # -- Compressed-tensors config --\n            cc = saved_config.get(\"compression_config\")\n            assert \"quantization_config\" not in saved_config\n            assert cc is not None\n            assert cc[\"quant_method\"] == \"compressed-tensors\"\n            assert cc[\"format\"] == \"float-quantized\"\n            assert cc[\"quantization_status\"] == \"compressed\"\n            assert cc[\"config_groups\"][\"group_0\"][\"targets\"] == [\"Linear\"]\n            assert cc[\"config_groups\"][\"group_0\"][\"weights\"][\"strategy\"] == \"block\"\n            assert cc[\"config_groups\"][\"group_0\"][\"weights\"][\"block_structure\"] == [128, 128]\n            assert cc[\"config_groups\"][\"group_0\"][\"input_activations\"][\"dynamic\"] is True\n\n            ignore = set(cc[\"ignore\"])\n            assert \"model.layers.0.self_attn.q_a_proj\" in ignore\n            assert \"model.layers.0.mlp.shared_experts.gate_proj\" in ignore\n            assert \"model.layers.0.mlp.experts.0.gate_proj\" not in ignore\n\n\n# ---------------------------------------------------------------------------\n# 5. Qwen3 dense — standard linear layers (no experts)\n# ---------------------------------------------------------------------------\n\n\ndef _make_tiny_qwen3_dense_config() -> PretrainedConfig:\n    config = AutoConfig.from_pretrained(\"Qwen/Qwen3-8B\", trust_remote_code=True)\n    config.num_hidden_layers = 1\n    config.hidden_size = 64\n    config.intermediate_size = 64\n    config.num_attention_heads = 2\n    config.num_key_value_heads = 2\n    if hasattr(config, \"layer_types\") and config.layer_types is not None:\n        config.layer_types = config.layer_types[:1]\n    return config\n\n\nclass TestQwen3Dense:\n    \"\"\"Qwen3 dense: standard MLP with gate_proj/up_proj (no experts).\"\"\"\n\n    def test_dense_linear_merge(self):\n        config = _make_tiny_qwen3_dense_config()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            model_path, adapter_path, output_path = (\n                root / \"model\",\n                root / \"adapter\",\n                root / \"merged\",\n            )\n\n            _save_model_to_disk(config, model_path, tokenizer_name=\"Qwen/Qwen3-8B\")\n            orig = AutoModelForCausalLM.from_pretrained(\n                model_path, trust_remote_code=True, dtype=torch.float32\n            )\n            orig_gate = orig.state_dict()[\"model.layers.0.mlp.gate_proj.weight\"].clone()\n\n            _save_dense_adapter(adapter_path, in_dim=64, out_dim=64, fill=FILL_A)\n            merged_sd = _run_build_and_reload(model_path, adapter_path, output_path)\n\n            delta = (merged_sd[\"model.layers.0.mlp.gate_proj.weight\"] - orig_gate).abs().sum()\n            assert delta > 0, \"Dense gate_proj not updated\"\n"
  },
  {
    "path": "tests/weights/test_lifecycle.py",
    "content": "\"\"\"End-to-end lifecycle test: train → save → download → build.\n\nTrains a tiny SFT model for 1 step, saves the checkpoint, downloads\nthe adapter via weights.download(), and builds a merged HF model via\nweights.build_hf_model().\n\nRequires TINKER_API_KEY and network access. Skipped otherwise.\n\"\"\"\n\nimport asyncio\nimport tempfile\nfrom pathlib import Path\nfrom typing import cast\n\nimport datasets\nimport pytest\nimport tinker\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.supervised.data import (\n    SupervisedDatasetFromHFDataset,\n    conversation_to_datum,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.weights import build_hf_model, download\n\n\n@pytest.mark.integration\nclass TestFullLifecycle:\n    \"\"\"Train 1 step → save → download → build merged HF model.\"\"\"\n\n    MODEL_NAME = \"Qwen/Qwen3-8B\"\n    RENDERER_NAME = \"qwen3\"\n    BATCH_SIZE = 4\n    MAX_LENGTH = 512\n    LORA_RANK = 8\n\n    def _train_and_save(self, log_path: str) -> str:\n        \"\"\"Train for 1 step and return the sampler checkpoint tinker:// path.\"\"\"\n        tokenizer = get_tokenizer(self.MODEL_NAME)\n        renderer = renderers.get_renderer(self.RENDERER_NAME, tokenizer)\n\n        # Load a small slice of data\n        dataset = datasets.load_dataset(\"allenai/tulu-3-sft-mixture\")\n        dataset = cast(datasets.DatasetDict, dataset)\n        train_ds = dataset[\"train\"].take(self.BATCH_SIZE)\n\n        def map_fn(row: dict) -> tinker.Datum:\n            return conversation_to_datum(row[\"messages\"], renderer, self.MAX_LENGTH)\n\n        sft_dataset = SupervisedDatasetFromHFDataset(\n            train_ds, batch_size=self.BATCH_SIZE, map_fn=map_fn\n        )\n\n        async def _run() -> str:\n            sc = tinker.ServiceClient()\n            tc = await sc.create_lora_training_client_async(\n                base_model=self.MODEL_NAME,\n                rank=self.LORA_RANK,\n            )\n\n            # Train 1 step\n            batch = sft_dataset.get_batch(0)\n            fwd_bwd = await tc.forward_backward_async(batch, loss_fn=\"cross_entropy\")\n            await fwd_bwd.result_async()\n            optim = await tc.optim_step_async({\"learning_rate\": 1e-4})\n            await optim.result_async()\n\n            # Save checkpoint\n            sampler_resp = await tc.save_weights_for_sampler_async(\"lifecycle_test\")\n            result = await sampler_resp.result_async()\n            return result.path\n\n        return asyncio.run(_run())\n\n    def test_train_download_build(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            log_path = str(root / \"logs\")\n            Path(log_path).mkdir()\n\n            # Step 1: Train and save\n            tinker_path = self._train_and_save(log_path)\n            assert tinker_path.startswith(\"tinker://\"), f\"Unexpected path format: {tinker_path}\"\n\n            # Step 2: Download\n            adapter_dir = download(\n                tinker_path=tinker_path,\n                output_dir=str(root / \"adapter\"),\n            )\n            adapter_path = Path(adapter_dir)\n            assert (adapter_path / \"adapter_model.safetensors\").exists(), (\n                f\"adapter_model.safetensors not found in {adapter_dir}\"\n            )\n            assert (adapter_path / \"adapter_config.json\").exists(), (\n                f\"adapter_config.json not found in {adapter_dir}\"\n            )\n\n            # 
Step 3: Build merged HF model\n            output_path = str(root / \"merged\")\n            build_hf_model(\n                base_model=self.MODEL_NAME,\n                adapter_path=adapter_dir,\n                output_path=output_path,\n            )\n\n            # Verify output looks like a valid HF model\n            out = Path(output_path)\n            assert (out / \"config.json\").exists(), \"config.json missing from merged model\"\n            assert any(out.glob(\"*.safetensors\")), \"No safetensors files in merged model\"\n            assert (out / \"tokenizer.json\").exists() or (out / \"tokenizer_config.json\").exists(), (\n                \"Tokenizer files missing from merged model\"\n            )\n"
  },
  {
    "path": "tests/weights/test_publish.py",
    "content": "\"\"\"Integration test for weights.publish_to_hf_hub().\n\nRequires HF authentication (HF_TOKEN env var or `hf auth login`).\nSkipped otherwise.\n\nCreates a temporary private repo, uploads a tiny dummy model, verifies\nthe upload, and cleans up the repo regardless of test outcome.\n\"\"\"\n\nimport contextlib\nimport json\nimport tempfile\nimport uuid\nfrom pathlib import Path\n\nimport pytest\n\nfrom tinker_cookbook.weights import publish_to_hf_hub\n\n\ndef _hf_username() -> str:\n    \"\"\"Get the authenticated HF username, or skip the test.\"\"\"\n    try:\n        from huggingface_hub import HfApi\n\n        api = HfApi()\n        info = api.whoami()\n        return info[\"name\"]\n    except Exception:\n        pytest.skip(\"HF authentication required (set HF_TOKEN or run `hf auth login`)\")\n        return \"\"  # unreachable, keeps type checker happy\n\n\ndef _create_dummy_model_dir(path: Path) -> None:\n    \"\"\"Create a minimal directory that looks like an HF model.\"\"\"\n    path.mkdir(parents=True)\n    (path / \"config.json\").write_text(json.dumps({\"model_type\": \"test\"}))\n    (path / \"README.md\").write_text(\"Test model for tinker_cookbook integration test\")\n\n\n@pytest.mark.integration\nclass TestPublishToHfHubIntegration:\n    def test_upload_and_verify(self):\n        username = _hf_username()\n        repo_id = f\"{username}/tinker-cookbook-test-{uuid.uuid4().hex[:8]}\"\n\n        from huggingface_hub import HfApi\n\n        api = HfApi()\n\n        try:\n            with tempfile.TemporaryDirectory() as tmpdir:\n                model_path = Path(tmpdir) / \"model\"\n                _create_dummy_model_dir(model_path)\n\n                url = publish_to_hf_hub(\n                    model_path=str(model_path),\n                    repo_id=repo_id,\n                    private=True,\n                )\n\n                assert url == f\"https://huggingface.co/{repo_id}\"\n\n                # Verify the repo exists and has our files\n                files = api.list_repo_files(repo_id=repo_id, repo_type=\"model\")\n                assert \"config.json\" in files\n                assert \"README.md\" in files\n        finally:\n            # Always clean up, even if test fails\n            with contextlib.suppress(Exception):\n                api.delete_repo(repo_id=repo_id, repo_type=\"model\")\n"
  },
  {
    "path": "tests/weights/test_quantized.py",
    "content": "\"\"\"End-to-end tests for quantized export (DeepSeek FP8).\n\nUses a tiny 1-layer DeepSeek V3 model created from config with synthetic\nrandom weights. Tests exercise the full pipeline including merge, quantize,\nshard layout preservation, config patching, and resume.\n\"\"\"\n\nimport json\nimport math\nfrom pathlib import Path\nfrom unittest.mock import patch\n\nimport pytest\nimport torch\nfrom safetensors.torch import load_file, save_file\n\nfrom tinker_cookbook.weights import build_hf_model\n\n# ---------------------------------------------------------------------------\n# Tiny DeepSeek model fixture\n# ---------------------------------------------------------------------------\n\n_HIDDEN = 64\n_INTER = 128\n_NUM_EXPERTS = 2\n_VOCAB = 256\n\n\ndef _deepseek_config() -> dict:\n    \"\"\"Minimal DeepSeek V3 config.\"\"\"\n    return {\n        \"model_type\": \"deepseek_v3\",\n        \"architectures\": [\"DeepseekV3ForCausalLM\"],\n        \"hidden_size\": _HIDDEN,\n        \"intermediate_size\": _INTER,\n        \"num_hidden_layers\": 1,\n        \"num_attention_heads\": 2,\n        \"num_key_value_heads\": 2,\n        \"n_routed_experts\": _NUM_EXPERTS,\n        \"vocab_size\": _VOCAB,\n    }\n\n\ndef _deepseek_state_dict() -> dict[str, torch.Tensor]:\n    \"\"\"Create synthetic weights for a tiny 1-layer DeepSeek V3 model.\"\"\"\n    sd: dict[str, torch.Tensor] = {}\n\n    # Embedding\n    sd[\"model.embed_tokens.weight\"] = torch.randn(_VOCAB, _HIDDEN, dtype=torch.bfloat16)\n\n    # Attention\n    for proj in (\"q_a_proj\", \"q_b_proj\", \"kv_a_proj_with_mqa\", \"kv_b_proj\"):\n        # Simplified dims — don't need to match real DeepSeek exactly\n        sd[f\"model.layers.0.self_attn.{proj}.weight\"] = torch.randn(\n            _HIDDEN, _HIDDEN, dtype=torch.bfloat16\n        )\n\n    # Layer norms\n    sd[\"model.layers.0.input_layernorm.weight\"] = torch.ones(_HIDDEN, dtype=torch.bfloat16)\n    sd[\"model.layers.0.post_attention_layernorm.weight\"] = torch.ones(_HIDDEN, dtype=torch.bfloat16)\n\n    # Router\n    sd[\"model.layers.0.mlp.gate.weight\"] = torch.randn(_NUM_EXPERTS, _HIDDEN, dtype=torch.bfloat16)\n\n    # Routed experts\n    for i in range(_NUM_EXPERTS):\n        for proj, shape in [\n            (\"gate_proj\", (_INTER, _HIDDEN)),\n            (\"up_proj\", (_INTER, _HIDDEN)),\n            (\"down_proj\", (_HIDDEN, _INTER)),\n        ]:\n            sd[f\"model.layers.0.mlp.experts.{i}.{proj}.weight\"] = torch.randn(\n                *shape, dtype=torch.bfloat16\n            )\n\n    # Shared experts\n    for proj, shape in [\n        (\"gate_proj\", (_INTER, _HIDDEN)),\n        (\"up_proj\", (_INTER, _HIDDEN)),\n        (\"down_proj\", (_HIDDEN, _INTER)),\n    ]:\n        sd[f\"model.layers.0.mlp.shared_experts.{proj}.weight\"] = torch.randn(\n            *shape, dtype=torch.bfloat16\n        )\n\n    # LM head\n    sd[\"lm_head.weight\"] = torch.randn(_VOCAB, _HIDDEN, dtype=torch.bfloat16)\n\n    # Final norm\n    sd[\"model.norm.weight\"] = torch.ones(_HIDDEN, dtype=torch.bfloat16)\n\n    return sd\n\n\ndef _split_into_shards(\n    sd: dict[str, torch.Tensor],\n) -> dict[str, dict[str, torch.Tensor]]:\n    \"\"\"Split state dict into 2 shards: attention+embed in shard 1, MLP+rest in shard 2.\"\"\"\n    shard1: dict[str, torch.Tensor] = {}\n    shard2: dict[str, torch.Tensor] = {}\n\n    for key, tensor in sd.items():\n        if \"self_attn\" in key or \"embed_tokens\" in key or \"input_layernorm\" in key:\n            shard1[key] = tensor\n      
  else:\n            shard2[key] = tensor\n\n    return {\n        \"model-00001-of-00002.safetensors\": shard1,\n        \"model-00002-of-00002.safetensors\": shard2,\n    }\n\n\ndef _create_deepseek_model(model_dir: Path, shards: dict[str, dict[str, torch.Tensor]]) -> None:\n    \"\"\"Write a sharded DeepSeek model to disk.\"\"\"\n    model_dir.mkdir(parents=True, exist_ok=True)\n    config = _deepseek_config()\n    (model_dir / \"config.json\").write_text(json.dumps(config))\n\n    weight_map: dict[str, str] = {}\n    for shard_name, tensors in shards.items():\n        save_file(tensors, str(model_dir / shard_name))\n        for key in tensors:\n            weight_map[key] = shard_name\n\n    total_size = sum(t.nelement() * t.element_size() for s in shards.values() for t in s.values())\n    index = {\"metadata\": {\"total_size\": total_size}, \"weight_map\": weight_map}\n    (model_dir / \"model.safetensors.index.json\").write_text(json.dumps(index))\n\n    # Minimal tokenizer\n    (model_dir / \"tokenizer_config.json\").write_text(\n        json.dumps({\"tokenizer_class\": \"PreTrainedTokenizerFast\"})\n    )\n    (model_dir / \"tokenizer.json\").write_text(\n        json.dumps(\n            {\n                \"version\": \"1.0\",\n                \"model\": {\"type\": \"BPE\", \"vocab\": {\"a\": 0, \"b\": 1}, \"merges\": []},\n                \"added_tokens\": [],\n            }\n        )\n    )\n\n    # Custom model code files (for copy_model_code_files test)\n    (model_dir / \"configuration_deepseek.py\").write_text(\"# DeepSeek config\\n\")\n    (model_dir / \"modeling_deepseek.py\").write_text(\"# DeepSeek model\\n\")\n\n\ndef _create_deepseek_adapter(adapter_dir: Path) -> None:\n    \"\"\"Create a LoRA adapter targeting attention and expert weights.\"\"\"\n    adapter_dir.mkdir(parents=True, exist_ok=True)\n\n    rank = 1\n    weights: dict[str, torch.Tensor] = {}\n\n    # Target attention q_a_proj\n    weights[\"base_model.model.model.layers.0.self_attn.q_a_proj.lora_A.weight\"] = (\n        torch.ones(rank, _HIDDEN) * 0.01\n    )\n    weights[\"base_model.model.model.layers.0.self_attn.q_a_proj.lora_B.weight\"] = torch.ones(\n        _HIDDEN, rank\n    )\n\n    # Target routed experts gate_proj (w1)\n    weights[\"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\"] = (\n        torch.ones(_NUM_EXPERTS, rank, _HIDDEN) * 0.01\n    )\n    weights[\"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\"] = torch.ones(\n        _NUM_EXPERTS, _INTER, rank\n    )\n\n    save_file(weights, str(adapter_dir / \"adapter_model.safetensors\"))\n    (adapter_dir / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": rank}))\n\n\n@pytest.fixture\ndef deepseek_model(tmp_path: Path):\n    \"\"\"Set up a tiny DeepSeek model + adapter.\"\"\"\n    sd = _deepseek_state_dict()\n    shards = _split_into_shards(sd)\n    model_dir = tmp_path / \"model\"\n    adapter_dir = tmp_path / \"adapter\"\n\n    _create_deepseek_model(model_dir, shards)\n    _create_deepseek_adapter(adapter_dir)\n\n    return model_dir, adapter_dir, sd\n\n\ndef _load_output(output_dir: Path) -> dict[str, torch.Tensor]:\n    \"\"\"Load all output tensors from a sharded output directory.\"\"\"\n    index_path = output_dir / \"model.safetensors.index.json\"\n    if index_path.exists():\n        with open(index_path) as f:\n            weight_map = json.load(f)[\"weight_map\"]\n        tensors: dict[str, torch.Tensor] = {}\n        for shard_name in sorted(set(weight_map.values())):\n       
     tensors.update(load_file(str(output_dir / shard_name)))\n        return tensors\n    single = output_dir / \"model.safetensors\"\n    assert single.exists()\n    return load_file(str(single))\n\n\n# ---------------------------------------------------------------------------\n# Branch 1: Dense weights (attention/embedding)\n# ---------------------------------------------------------------------------\n\n\nclass TestDenseWeights:\n    def test_dense_weights_change_after_merge(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, orig_sd = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        key = \"model.layers.0.self_attn.q_a_proj.weight\"\n        delta = (out[key].float() - orig_sd[key].float()).abs().sum()\n        assert delta > 0, \"q_a_proj should have changed after merge\"\n\n    def test_dense_weights_stay_bf16(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        assert out[\"model.layers.0.self_attn.q_a_proj.weight\"].dtype == torch.bfloat16\n\n    def test_untargeted_dense_unchanged(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, orig_sd = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        key = \"model.embed_tokens.weight\"\n        assert torch.equal(out[key], orig_sd[key]), \"embed_tokens should be bit-identical\"\n\n\n# ---------------------------------------------------------------------------\n# Branch 2: Routed expert weights\n# ---------------------------------------------------------------------------\n\n\nclass TestRoutedExperts:\n    def test_routed_experts_change_after_merge(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, orig_sd = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        from tinker_cookbook.weights._export._quantized import dequantize_blockwise\n\n        key = \"model.layers.0.mlp.experts.0.gate_proj.weight\"\n        scale_key = key.replace(\".weight\", \".weight_scale\")\n        merged = dequantize_blockwise(out[key], out[scale_key], dtype=torch.bfloat16)\n        delta = (merged.float() - orig_sd[key].float()).abs().sum()\n        assert delta > 0, \"Expert gate_proj should have changed after merge\"\n\n    def test_routed_experts_quantized_to_fp8(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = 
deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        key = \"model.layers.0.mlp.experts.0.gate_proj.weight\"\n        assert out[key].dtype == torch.float8_e4m3fn\n\n    def test_expert_has_float32_scale(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        scale_key = \"model.layers.0.mlp.experts.0.gate_proj.weight_scale\"\n        assert scale_key in out\n        assert out[scale_key].dtype == torch.float32\n\n    def test_scale_shape_matches_block_structure(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        scale = out[\"model.layers.0.mlp.experts.0.gate_proj.weight_scale\"]\n        # gate_proj shape is (_INTER, _HIDDEN) = (128, 64)\n        # block_size = 128, so scale should be ceil(128/128) x ceil(64/128) = (1, 1)\n        expected = (math.ceil(_INTER / 128), math.ceil(_HIDDEN / 128))\n        assert scale.shape == expected\n\n\n# ---------------------------------------------------------------------------\n# Branch 3: Shared expert weights\n# ---------------------------------------------------------------------------\n\n\nclass TestSharedExperts:\n    def test_shared_experts_stay_bf16(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        key = \"model.layers.0.mlp.shared_experts.gate_proj.weight\"\n        assert out[key].dtype == torch.bfloat16\n\n    def test_shared_experts_no_scale(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        out = _load_output(output_dir)\n        assert \"model.layers.0.mlp.shared_experts.gate_proj.weight_scale\" not in out\n\n\n# ---------------------------------------------------------------------------\n# Branch 4: Shard layout preservation\n# ---------------------------------------------------------------------------\n\n\nclass TestShardLayout:\n    def test_two_shard_input_produces_two_shard_output(self, tmp_path: Path, 
deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        index_path = output_dir / \"model.safetensors.index.json\"\n        assert index_path.exists()\n        with open(index_path) as f:\n            index = json.load(f)\n        shard_files = set(index[\"weight_map\"].values())\n        assert len(shard_files) == 2\n\n    def test_index_consistent(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        index_path = output_dir / \"model.safetensors.index.json\"\n        with open(index_path) as f:\n            index = json.load(f)\n\n        # All listed files should exist\n        for shard_file in set(index[\"weight_map\"].values()):\n            assert (output_dir / shard_file).exists(), f\"Missing shard: {shard_file}\"\n\n        # All keys in weight_map should exist in corresponding shard\n        for key, shard_file in index[\"weight_map\"].items():\n            shard_tensors = load_file(str(output_dir / shard_file))\n            assert key in shard_tensors, f\"Key {key} not in {shard_file}\"\n\n\n# ---------------------------------------------------------------------------\n# Branch 5: Config and metadata\n# ---------------------------------------------------------------------------\n\n\nclass TestConfigMetadata:\n    def test_compression_config_present(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        config = json.loads((output_dir / \"config.json\").read_text())\n        assert \"compression_config\" in config\n        cc = config[\"compression_config\"]\n        assert cc[\"quant_method\"] == \"compressed-tensors\"\n        assert cc[\"format\"] == \"float-quantized\"\n\n    def test_quantization_config_absent(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        # Add a quantization_config to the input config to verify it gets removed\n        input_config = json.loads((model_dir / \"config.json\").read_text())\n        input_config[\"quantization_config\"] = {\"quant_method\": \"fp8\"}\n        (model_dir / \"config.json\").write_text(json.dumps(input_config))\n\n        output_dir = tmp_path / \"output\"\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        config = json.loads((output_dir / \"config.json\").read_text())\n        assert \"quantization_config\" not in config\n\n    def test_ignore_list_correct(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, 
_ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        config = json.loads((output_dir / \"config.json\").read_text())\n        ignore = config[\"compression_config\"][\"ignore\"]\n        # Dense projections should be in ignore\n        assert \"model.layers.0.self_attn.q_a_proj\" in ignore\n        # Routed experts should NOT be in ignore\n        routed_in_ignore = [\n            x for x in ignore if \".mlp.experts.\" in x and \".shared_experts.\" not in x\n        ]\n        assert len(routed_in_ignore) == 0, f\"Routed experts in ignore: {routed_in_ignore}\"\n\n    def test_model_code_files_copied(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        assert (output_dir / \"configuration_deepseek.py\").exists()\n        assert (output_dir / \"modeling_deepseek.py\").exists()\n\n\n# ---------------------------------------------------------------------------\n# Branch 6: Resume (crash + restart)\n# ---------------------------------------------------------------------------\n\n\nclass TestResume:\n    def test_crash_after_shard_1_shows_in_progress(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        # Monkeypatch to crash after first shard is saved\n        call_count = 0\n        original_load = load_file\n\n        def crash_on_second_shard(path, *args, **kwargs):\n            nonlocal call_count\n            result = original_load(path, *args, **kwargs)\n            # Count loads from model_dir (not adapter or output)\n            if str(model_dir) in str(path) and \"model-\" in str(path):\n                call_count += 1\n                if call_count >= 2:\n                    raise RuntimeError(\"Simulated crash\")\n            return result\n\n        with patch(\n            \"tinker_cookbook.weights._export._quantized.load_file\",\n            side_effect=crash_on_second_shard,\n        ):\n            with pytest.raises(RuntimeError, match=\"Simulated crash\"):\n                build_hf_model(\n                    base_model=str(model_dir),\n                    adapter_path=str(adapter_dir),\n                    output_path=str(output_dir),\n                    quantize=\"experts-fp8\",\n                    serving_format=\"vllm\",\n                )\n\n        # Check merge state\n        state = json.loads((output_dir / \"merge_state.json\").read_text())\n        assert state[\"status\"] == \"in_progress\"\n        assert len(state[\"completed_shards\"]) == 1\n\n    def test_resume_completes(self, tmp_path: Path, deepseek_model):\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        # First run: crash after shard 1\n        call_count = 0\n        original_load = load_file\n\n        def crash_on_second_shard(path, *args, **kwargs):\n            nonlocal call_count\n            result = original_load(path, *args, **kwargs)\n            if str(model_dir) 
in str(path) and \"model-\" in str(path):\n                call_count += 1\n                if call_count >= 2:\n                    raise RuntimeError(\"Simulated crash\")\n            return result\n\n        with patch(\n            \"tinker_cookbook.weights._export._quantized.load_file\",\n            side_effect=crash_on_second_shard,\n        ):\n            with pytest.raises(RuntimeError):\n                build_hf_model(\n                    base_model=str(model_dir),\n                    adapter_path=str(adapter_dir),\n                    output_path=str(output_dir),\n                    quantize=\"experts-fp8\",\n                    serving_format=\"vllm\",\n                )\n\n        # Second run: resume should complete\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            quantize=\"experts-fp8\",\n            serving_format=\"vllm\",\n        )\n\n        state = json.loads((output_dir / \"merge_state.json\").read_text())\n        assert state[\"status\"] == \"completed\"\n        assert len(state[\"completed_shards\"]) == 2\n\n\n# ---------------------------------------------------------------------------\n# Branch 7: New API interactions\n# ---------------------------------------------------------------------------\n\n\nclass TestApiValidation:\n    def test_quantize_none_does_standard_merge(self, tmp_path: Path, deepseek_model):\n        \"\"\"quantize=None with DeepSeek model should do standard BF16 merge.\"\"\"\n        model_dir, adapter_dir, _ = deepseek_model\n        output_dir = tmp_path / \"output\"\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n        )\n\n        out = _load_output(output_dir)\n        # No FP8 tensors\n        for key, tensor in out.items():\n            assert tensor.dtype != torch.float8_e4m3fn, f\"{key} should not be FP8\"\n        # No compression_config\n        config = json.loads((output_dir / \"config.json\").read_text())\n        assert \"compression_config\" not in config\n\n    def test_quantize_without_serving_format_raises(self, tmp_path: Path):\n        with pytest.raises(ValueError, match=\"serving_format\"):\n            build_hf_model(\n                base_model=str(tmp_path),\n                adapter_path=str(tmp_path),\n                output_path=str(tmp_path / \"out\"),\n                quantize=\"experts-fp8\",\n            )\n\n    def test_serving_format_without_quantize_raises(self, tmp_path: Path):\n        with pytest.raises(ValueError, match=\"quantize\"):\n            build_hf_model(\n                base_model=str(tmp_path),\n                adapter_path=str(tmp_path),\n                output_path=str(tmp_path / \"out\"),\n                serving_format=\"vllm\",\n            )\n\n    def test_quantize_with_wrong_dtype_raises(self, tmp_path: Path):\n        with pytest.raises(ValueError, match=\"bfloat16\"):\n            build_hf_model(\n                base_model=str(tmp_path),\n                adapter_path=str(tmp_path),\n                output_path=str(tmp_path / \"out\"),\n                quantize=\"experts-fp8\",\n                serving_format=\"vllm\",\n                dtype=\"float32\",\n            )\n\n    def test_unknown_quantize_raises(self, tmp_path: Path):\n        with pytest.raises(ValueError, match=\"quantize\"):\n            build_hf_model(\n                
base_model=str(tmp_path),\n                adapter_path=str(tmp_path),\n                output_path=str(tmp_path / \"out\"),\n                quantize=\"unknown\",\n                serving_format=\"vllm\",\n            )\n\n    def test_unknown_serving_format_raises(self, tmp_path: Path):\n        with pytest.raises(ValueError, match=\"serving_format\"):\n            build_hf_model(\n                base_model=str(tmp_path),\n                adapter_path=str(tmp_path),\n                output_path=str(tmp_path / \"out\"),\n                quantize=\"experts-fp8\",\n                serving_format=\"unknown\",\n            )\n"
  },
  {
    "path": "tests/weights/test_quantized_equivalence.py",
    "content": "\"\"\"Equivalence tests: verify our quantized export matches PR #470 behavior.\n\nPR #470 (tinker_cookbook/weights/_deepseek.py) established the reference behavior\nfor DeepSeek FP8 export. This test suite verifies that our reimplementation\nin _export/_quantized.py produces equivalent output.\n\nUses a tiny 1-layer DeepSeek V3 model with synthetic weights — no network needed.\n\"\"\"\n\nimport json\nimport math\nfrom pathlib import Path\n\nimport pytest\nimport torch\nfrom safetensors.torch import load_file, save_file\n\nfrom tinker_cookbook.weights import build_hf_model\n\n# ---------------------------------------------------------------------------\n# Constants matching PR #470\n# ---------------------------------------------------------------------------\n\n_HIDDEN = 64\n_INTER = 128  # moe_intermediate_size in real DeepSeek, but simplified here\n_NUM_EXPERTS = 2\n_VOCAB = 256\n_BLOCK_SIZE = 128  # DeepSeek native FP8 block size\n\n# PR #470 used these suffixes to determine what goes in the ignore list\n_LINEAR_PROJ_SUFFIXES = (\n    \".q_proj.weight\",\n    \".q_a_proj.weight\",\n    \".q_b_proj.weight\",\n    \".kv_a_proj_with_mqa.weight\",\n    \".kv_b_proj.weight\",\n    \".o_proj.weight\",\n    \".gate_proj.weight\",\n    \".up_proj.weight\",\n    \".down_proj.weight\",\n)\n\n\n# ---------------------------------------------------------------------------\n# Test model setup\n# ---------------------------------------------------------------------------\n\n\ndef _create_test_model(model_dir: Path) -> dict[str, torch.Tensor]:\n    \"\"\"Create a tiny sharded DeepSeek V3 model matching PR #470 test structure.\"\"\"\n    sd: dict[str, torch.Tensor] = {}\n\n    # Embedding\n    sd[\"model.embed_tokens.weight\"] = torch.randn(_VOCAB, _HIDDEN, dtype=torch.bfloat16)\n\n    # Attention (using DeepSeek-specific projection names)\n    for proj in (\"q_a_proj\", \"q_b_proj\", \"kv_a_proj_with_mqa\", \"kv_b_proj\"):\n        sd[f\"model.layers.0.self_attn.{proj}.weight\"] = torch.randn(\n            _HIDDEN, _HIDDEN, dtype=torch.bfloat16\n        )\n\n    # Layer norms\n    sd[\"model.layers.0.input_layernorm.weight\"] = torch.ones(_HIDDEN, dtype=torch.bfloat16)\n    sd[\"model.layers.0.post_attention_layernorm.weight\"] = torch.ones(_HIDDEN, dtype=torch.bfloat16)\n\n    # Router\n    sd[\"model.layers.0.mlp.gate.weight\"] = torch.randn(_NUM_EXPERTS, _HIDDEN, dtype=torch.bfloat16)\n\n    # Routed experts (gate, up, down for each expert)\n    for i in range(_NUM_EXPERTS):\n        sd[f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\"] = torch.randn(\n            _INTER, _HIDDEN, dtype=torch.bfloat16\n        )\n        sd[f\"model.layers.0.mlp.experts.{i}.up_proj.weight\"] = torch.randn(\n            _INTER, _HIDDEN, dtype=torch.bfloat16\n        )\n        sd[f\"model.layers.0.mlp.experts.{i}.down_proj.weight\"] = torch.randn(\n            _HIDDEN, _INTER, dtype=torch.bfloat16\n        )\n\n    # Shared experts\n    sd[\"model.layers.0.mlp.shared_experts.gate_proj.weight\"] = torch.randn(\n        _INTER, _HIDDEN, dtype=torch.bfloat16\n    )\n    sd[\"model.layers.0.mlp.shared_experts.up_proj.weight\"] = torch.randn(\n        _INTER, _HIDDEN, dtype=torch.bfloat16\n    )\n    sd[\"model.layers.0.mlp.shared_experts.down_proj.weight\"] = torch.randn(\n        _HIDDEN, _INTER, dtype=torch.bfloat16\n    )\n\n    # LM head\n    sd[\"lm_head.weight\"] = torch.randn(_VOCAB, _HIDDEN, dtype=torch.bfloat16)\n\n    # Final norm\n    sd[\"model.norm.weight\"] = torch.ones(_HIDDEN, 
dtype=torch.bfloat16)\n\n    # Reshard into 2 shards (matching PR #470's test pattern)\n    shard1_keys = {\n        \"model.layers.0.self_attn.q_a_proj.weight\",\n        \"model.layers.0.mlp.experts.0.gate_proj.weight\",\n        \"model.layers.0.mlp.experts.1.up_proj.weight\",\n        \"model.embed_tokens.weight\",\n        \"model.layers.0.input_layernorm.weight\",\n    }\n\n    shard1 = {k: v for k, v in sd.items() if k in shard1_keys}\n    shard2 = {k: v for k, v in sd.items() if k not in shard1_keys}\n\n    shards = {\n        \"model-00001-of-00002.safetensors\": shard1,\n        \"model-00002-of-00002.safetensors\": shard2,\n    }\n\n    model_dir.mkdir(parents=True, exist_ok=True)\n    config = {\n        \"model_type\": \"deepseek_v3\",\n        \"architectures\": [\"DeepseekV3ForCausalLM\"],\n        \"hidden_size\": _HIDDEN,\n        \"intermediate_size\": _INTER,\n        \"num_hidden_layers\": 1,\n        \"num_attention_heads\": 2,\n        \"num_key_value_heads\": 2,\n        \"n_routed_experts\": _NUM_EXPERTS,\n        \"vocab_size\": _VOCAB,\n    }\n    (model_dir / \"config.json\").write_text(json.dumps(config))\n\n    weight_map: dict[str, str] = {}\n    total_size = 0\n    for shard_name, tensors in shards.items():\n        save_file(tensors, str(model_dir / shard_name))\n        for key, tensor in tensors.items():\n            weight_map[key] = shard_name\n            total_size += tensor.nelement() * tensor.element_size()\n\n    index = {\"metadata\": {\"total_size\": total_size}, \"weight_map\": weight_map}\n    (model_dir / \"model.safetensors.index.json\").write_text(json.dumps(index))\n\n    # Tokenizer\n    (model_dir / \"tokenizer_config.json\").write_text(\n        json.dumps({\"tokenizer_class\": \"PreTrainedTokenizerFast\"})\n    )\n    (model_dir / \"tokenizer.json\").write_text(\n        json.dumps(\n            {\n                \"version\": \"1.0\",\n                \"model\": {\"type\": \"BPE\", \"vocab\": {\"a\": 0, \"b\": 1}, \"merges\": []},\n                \"added_tokens\": [],\n            }\n        )\n    )\n\n    # Custom model code\n    (model_dir / \"configuration_deepseek.py\").write_text(\"# config\\n\")\n    (model_dir / \"modeling_deepseek.py\").write_text(\"# model\\n\")\n\n    return sd\n\n\ndef _create_test_adapter(adapter_dir: Path) -> None:\n    \"\"\"Create adapter targeting both dense and expert weights (matching PR #470).\"\"\"\n    adapter_dir.mkdir(parents=True, exist_ok=True)\n    rank = 1\n    weights: dict[str, torch.Tensor] = {}\n\n    # Dense: q_a_proj\n    weights[\"base_model.model.model.layers.0.self_attn.q_a_proj.lora_A.weight\"] = (\n        torch.ones(rank, _HIDDEN, dtype=torch.bfloat16) * 0.01\n    )\n    weights[\"base_model.model.model.layers.0.self_attn.q_a_proj.lora_B.weight\"] = torch.ones(\n        _HIDDEN, rank, dtype=torch.bfloat16\n    )\n\n    # Expert gate (w1)\n    weights[\"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\"] = (\n        torch.ones(_NUM_EXPERTS, rank, _HIDDEN, dtype=torch.bfloat16) * 0.01\n    )\n    weights[\"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\"] = torch.ones(\n        _NUM_EXPERTS, _INTER, rank, dtype=torch.bfloat16\n    )\n\n    # Expert up (w3)\n    weights[\"base_model.model.model.layers.0.mlp.experts.w3.lora_A.weight\"] = (\n        torch.ones(_NUM_EXPERTS, rank, _HIDDEN, dtype=torch.bfloat16) * 0.05\n    )\n    weights[\"base_model.model.model.layers.0.mlp.experts.w3.lora_B.weight\"] = torch.ones(\n        _NUM_EXPERTS, _INTER, rank, 
dtype=torch.bfloat16\n    )\n\n    save_file(weights, str(adapter_dir / \"adapter_model.safetensors\"))\n    (adapter_dir / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": rank}))\n\n\ndef _load_all_output_tensors(output_dir: Path) -> dict[str, torch.Tensor]:\n    \"\"\"Load all output tensors from safetensors files.\"\"\"\n    tensors: dict[str, torch.Tensor] = {}\n    for path in sorted(output_dir.glob(\"*.safetensors\")):\n        tensors.update(load_file(str(path)))\n    return tensors\n\n\n@pytest.fixture\ndef equivalence_model(tmp_path: Path):\n    \"\"\"Set up test model + adapter + run build.\"\"\"\n    model_dir = tmp_path / \"model\"\n    adapter_dir = tmp_path / \"adapter\"\n    output_dir = tmp_path / \"output\"\n\n    orig_sd = _create_test_model(model_dir)\n    _create_test_adapter(adapter_dir)\n\n    build_hf_model(\n        base_model=str(model_dir),\n        adapter_path=str(adapter_dir),\n        output_path=str(output_dir),\n        quantize=\"experts-fp8\",\n        serving_format=\"vllm\",\n    )\n\n    saved_sd = _load_all_output_tensors(output_dir)\n    saved_config = json.loads((output_dir / \"config.json\").read_text())\n    saved_index = json.loads((output_dir / \"model.safetensors.index.json\").read_text())\n\n    return {\n        \"orig_sd\": orig_sd,\n        \"saved_sd\": saved_sd,\n        \"saved_config\": saved_config,\n        \"saved_index\": saved_index,\n        \"output_dir\": output_dir,\n        \"model_dir\": model_dir,\n    }\n\n\n# ---------------------------------------------------------------------------\n# 1. Scale tensor naming: must use .weight_scale (not .weight_scale_inv)\n#    PR #470 uses compressed-tensors convention: .weight_scale\n# ---------------------------------------------------------------------------\n\n\nclass TestScaleTensorNaming:\n    def test_routed_expert_scales_use_weight_scale_name(self, equivalence_model):\n        \"\"\"PR #470 emits .weight_scale, not .weight_scale_inv.\"\"\"\n        sd = equivalence_model[\"saved_sd\"]\n\n        for i in range(_NUM_EXPERTS):\n            for proj in (\"gate_proj\", \"up_proj\", \"down_proj\"):\n                key = f\"model.layers.0.mlp.experts.{i}.{proj}.weight_scale\"\n                inv_key = f\"model.layers.0.mlp.experts.{i}.{proj}.weight_scale_inv\"\n                assert key in sd, f\"Expected {key} in output (PR #470 convention)\"\n                assert inv_key not in sd, f\"Should not have {inv_key} (DeepSeek native convention)\"\n\n    def test_no_weight_scale_inv_in_output(self, equivalence_model):\n        \"\"\"PR #470 explicitly asserts: no .weight_scale_inv in output.\"\"\"\n        sd = equivalence_model[\"saved_sd\"]\n        inv_keys = [k for k in sd if k.endswith(\".weight_scale_inv\")]\n        assert not inv_keys, f\"Found .weight_scale_inv keys (should be .weight_scale): {inv_keys}\"\n\n\n# ---------------------------------------------------------------------------\n# 2. 
Compressed-tensors config schema\n#    PR #470 uses strategy=\"block\" with block_structure=[128, 128]\n#    and includes input_activations with dynamic=True\n# ---------------------------------------------------------------------------\n\n\nclass TestCompressedTensorsConfig:\n    def test_weights_strategy_is_block(self, equivalence_model):\n        \"\"\"PR #470: config_groups.group_0.weights.strategy == 'block'.\"\"\"\n        cc = equivalence_model[\"saved_config\"][\"compression_config\"]\n        weights = cc[\"config_groups\"][\"group_0\"][\"weights\"]\n        assert weights[\"strategy\"] == \"block\", (\n            f\"Expected strategy='block' (PR #470), got {weights.get('strategy')!r}\"\n        )\n\n    def test_block_structure_present(self, equivalence_model):\n        \"\"\"PR #470: config_groups.group_0.weights.block_structure == [128, 128].\"\"\"\n        cc = equivalence_model[\"saved_config\"][\"compression_config\"]\n        weights = cc[\"config_groups\"][\"group_0\"][\"weights\"]\n        assert weights.get(\"block_structure\") == [128, 128], (\n            f\"Expected block_structure=[128, 128], got {weights.get('block_structure')}\"\n        )\n\n    def test_input_activations_dynamic(self, equivalence_model):\n        \"\"\"PR #470: input_activations with dynamic=True.\"\"\"\n        cc = equivalence_model[\"saved_config\"][\"compression_config\"]\n        group = cc[\"config_groups\"][\"group_0\"]\n        assert \"input_activations\" in group, \"Missing input_activations section\"\n        assert group[\"input_activations\"][\"dynamic\"] is True, (\n            \"input_activations.dynamic should be True\"\n        )\n\n    def test_quant_method_compressed_tensors(self, equivalence_model):\n        cc = equivalence_model[\"saved_config\"][\"compression_config\"]\n        assert cc[\"quant_method\"] == \"compressed-tensors\"\n\n    def test_format_float_quantized(self, equivalence_model):\n        cc = equivalence_model[\"saved_config\"][\"compression_config\"]\n        assert cc[\"format\"] == \"float-quantized\"\n\n    def test_quantization_status_compressed(self, equivalence_model):\n        cc = equivalence_model[\"saved_config\"][\"compression_config\"]\n        assert cc[\"quantization_status\"] == \"compressed\"\n\n    def test_quantization_config_absent(self, equivalence_model):\n        assert \"quantization_config\" not in equivalence_model[\"saved_config\"]\n\n\n# ---------------------------------------------------------------------------\n# 3. 
Ignore list: PR #470 uses _LINEAR_PROJ_SUFFIXES, not all .weight keys\n# ---------------------------------------------------------------------------\n\n\nclass TestIgnoreList:\n    def test_dense_projections_in_ignore(self, equivalence_model):\n        \"\"\"Dense linear projections must be in ignore list.\"\"\"\n        ignore = set(equivalence_model[\"saved_config\"][\"compression_config\"][\"ignore\"])\n        assert \"model.layers.0.self_attn.q_a_proj\" in ignore\n        assert \"model.layers.0.self_attn.q_b_proj\" in ignore\n        assert \"model.layers.0.self_attn.kv_a_proj_with_mqa\" in ignore\n        assert \"model.layers.0.self_attn.kv_b_proj\" in ignore\n\n    def test_shared_experts_in_ignore(self, equivalence_model):\n        ignore = set(equivalence_model[\"saved_config\"][\"compression_config\"][\"ignore\"])\n        assert \"model.layers.0.mlp.shared_experts.gate_proj\" in ignore\n        assert \"model.layers.0.mlp.shared_experts.up_proj\" in ignore\n        assert \"model.layers.0.mlp.shared_experts.down_proj\" in ignore\n\n    def test_routed_experts_not_in_ignore(self, equivalence_model):\n        ignore = set(equivalence_model[\"saved_config\"][\"compression_config\"][\"ignore\"])\n        for i in range(_NUM_EXPERTS):\n            for proj in (\"gate_proj\", \"up_proj\", \"down_proj\"):\n                assert f\"model.layers.0.mlp.experts.{i}.{proj}\" not in ignore\n\n    def test_lm_head_in_ignore(self, equivalence_model):\n        \"\"\"PR #470 explicitly adds lm_head to ignore if not quantized.\"\"\"\n        ignore = set(equivalence_model[\"saved_config\"][\"compression_config\"][\"ignore\"])\n        assert \"lm_head\" in ignore\n\n\n# ---------------------------------------------------------------------------\n# 4. Weight dtype and merge correctness\n# ---------------------------------------------------------------------------\n\n\nclass TestWeightDtypes:\n    def test_dense_weights_bf16(self, equivalence_model):\n        sd = equivalence_model[\"saved_sd\"]\n        assert sd[\"model.layers.0.self_attn.q_a_proj.weight\"].dtype == torch.bfloat16\n\n    def test_routed_experts_fp8(self, equivalence_model):\n        sd = equivalence_model[\"saved_sd\"]\n        for i in range(_NUM_EXPERTS):\n            for proj in (\"gate_proj\", \"up_proj\", \"down_proj\"):\n                key = f\"model.layers.0.mlp.experts.{i}.{proj}.weight\"\n                assert sd[key].dtype == torch.float8_e4m3fn, f\"{key} should be FP8\"\n\n    def test_routed_expert_scales_float32(self, equivalence_model):\n        sd = equivalence_model[\"saved_sd\"]\n        for i in range(_NUM_EXPERTS):\n            for proj in (\"gate_proj\", \"up_proj\", \"down_proj\"):\n                key = f\"model.layers.0.mlp.experts.{i}.{proj}.weight_scale\"\n                assert sd[key].dtype == torch.float32, f\"{key} should be float32\"\n\n    def test_shared_experts_bf16(self, equivalence_model):\n        sd = equivalence_model[\"saved_sd\"]\n        for proj in (\"gate_proj\", \"up_proj\", \"down_proj\"):\n            key = f\"model.layers.0.mlp.shared_experts.{proj}.weight\"\n            assert sd[key].dtype == torch.bfloat16\n\n    def test_dense_weight_changed_after_merge(self, equivalence_model):\n        orig = equivalence_model[\"orig_sd\"]\n        saved = equivalence_model[\"saved_sd\"]\n        key = \"model.layers.0.self_attn.q_a_proj.weight\"\n        delta = (saved[key].float() - orig[key].float()).abs().sum()\n        assert delta > 0\n\n    def 
test_untargeted_embedding_unchanged(self, equivalence_model):\n        orig = equivalence_model[\"orig_sd\"]\n        saved = equivalence_model[\"saved_sd\"]\n        assert torch.equal(saved[\"model.embed_tokens.weight\"], orig[\"model.embed_tokens.weight\"])\n\n\n# ---------------------------------------------------------------------------\n# 5. FP8 scale shape matches block structure\n# ---------------------------------------------------------------------------\n\n\nclass TestScaleShapes:\n    def test_scale_shape_matches_blockwise_quantization(self, equivalence_model):\n        \"\"\"Scale shape must be ceil(rows/128) x ceil(cols/128).\"\"\"\n        sd = equivalence_model[\"saved_sd\"]\n        # gate_proj: (_INTER, _HIDDEN) = (128, 64)\n        scale = sd[\"model.layers.0.mlp.experts.0.gate_proj.weight_scale\"]\n        expected = (math.ceil(_INTER / _BLOCK_SIZE), math.ceil(_HIDDEN / _BLOCK_SIZE))\n        assert scale.shape == expected, f\"Expected {expected}, got {tuple(scale.shape)}\"\n\n    def test_down_proj_scale_shape(self, equivalence_model):\n        sd = equivalence_model[\"saved_sd\"]\n        # down_proj: (_HIDDEN, _INTER) = (64, 128)\n        scale = sd[\"model.layers.0.mlp.experts.0.down_proj.weight_scale\"]\n        expected = (math.ceil(_HIDDEN / _BLOCK_SIZE), math.ceil(_INTER / _BLOCK_SIZE))\n        assert scale.shape == expected\n\n\n# ---------------------------------------------------------------------------\n# 6. Shard layout and index consistency\n# ---------------------------------------------------------------------------\n\n\nclass TestShardConsistency:\n    def test_two_shard_output(self, equivalence_model):\n        index = equivalence_model[\"saved_index\"]\n        assert len(set(index[\"weight_map\"].values())) == 2\n\n    def test_index_covers_all_tensors(self, equivalence_model):\n        \"\"\"PR #470: weight_map should cover every emitted tensor exactly once.\"\"\"\n        index = equivalence_model[\"saved_index\"]\n        sd = equivalence_model[\"saved_sd\"]\n        assert set(index[\"weight_map\"]) == set(sd)\n\n    def test_all_shards_referenced_and_exist(self, equivalence_model):\n        index = equivalence_model[\"saved_index\"]\n        output_dir = equivalence_model[\"output_dir\"]\n        for shard_file in set(index[\"weight_map\"].values()):\n            assert (output_dir / shard_file).exists()\n\n    def test_scale_tensors_in_same_shard_as_weights(self, equivalence_model):\n        \"\"\"PR #470: scale tensors should be alongside their weight tensors.\"\"\"\n        wm = equivalence_model[\"saved_index\"][\"weight_map\"]\n        for i in range(_NUM_EXPERTS):\n            for proj in (\"gate_proj\", \"up_proj\", \"down_proj\"):\n                weight_key = f\"model.layers.0.mlp.experts.{i}.{proj}.weight\"\n                scale_key = f\"model.layers.0.mlp.experts.{i}.{proj}.weight_scale\"\n                if weight_key in wm and scale_key in wm:\n                    assert wm[weight_key] == wm[scale_key], (\n                        f\"Scale {scale_key} not in same shard as {weight_key}\"\n                    )\n\n\n# ---------------------------------------------------------------------------\n# 7. 
Custom files copied\n# ---------------------------------------------------------------------------\n\n\nclass TestCustomFiles:\n    def test_configuration_deepseek_copied(self, equivalence_model):\n        assert (equivalence_model[\"output_dir\"] / \"configuration_deepseek.py\").exists()\n\n    def test_modeling_deepseek_copied(self, equivalence_model):\n        assert (equivalence_model[\"output_dir\"] / \"modeling_deepseek.py\").exists()\n"
  },
  {
    "path": "tests/weights/test_strategy_consistency.py",
    "content": "\"\"\"Verify that merge_strategy='full' and merge_strategy='shard' produce identical output.\n\nUses a tiny real Qwen3 dense model to exercise the full pipeline including\ntokenizer saving, config handling, and actual HF model loading.\n\nRequires network access to download Qwen3-8B config + tokenizer on first run\n(cached by HF Hub afterwards).\n\"\"\"\n\nimport json\nimport tempfile\nfrom pathlib import Path\n\nimport torch\nfrom safetensors.torch import load_file, save_file\nfrom transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig\n\nfrom tinker_cookbook.weights import build_hf_model\n\nFILL = 0.01\n\n\ndef _make_tiny_qwen3_dense_config() -> PretrainedConfig:\n    config = AutoConfig.from_pretrained(\"Qwen/Qwen3-8B\", trust_remote_code=True)\n    config.num_hidden_layers = 1\n    config.hidden_size = 64\n    config.intermediate_size = 64\n    config.num_attention_heads = 2\n    config.num_key_value_heads = 2\n    if hasattr(config, \"layer_types\") and config.layer_types is not None:\n        config.layer_types = config.layer_types[:1]\n    return config\n\n\ndef _save_model_to_disk(config: PretrainedConfig, path: Path) -> None:\n    # Save in bfloat16 to match real-world models. This ensures both full\n    # (which loads as bfloat16 by default) and shard (which preserves on-disk\n    # dtype) paths work with the same precision.\n    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, dtype=torch.bfloat16)\n    model.save_pretrained(path)\n    tok = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-8B\", trust_remote_code=True)\n    tok.save_pretrained(path)\n\n\ndef _save_adapter(path: Path, *, model_path: Path) -> None:\n    \"\"\"Create adapter with LoRA weights matching the model's actual dimensions.\"\"\"\n    model = AutoModelForCausalLM.from_pretrained(\n        model_path, trust_remote_code=True, dtype=torch.bfloat16\n    )\n    sd = model.state_dict()\n\n    rank = 1\n    gate_shape = sd[\"model.layers.0.mlp.gate_proj.weight\"].shape  # (out, in)\n    q_shape = sd[\"model.layers.0.self_attn.q_proj.weight\"].shape  # (out, in)\n\n    weights = {\n        \"base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight\": (\n            torch.ones(rank, gate_shape[1]) * FILL\n        ),\n        \"base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight\": torch.ones(\n            gate_shape[0], rank\n        ),\n        \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": (\n            torch.ones(rank, q_shape[1]) * FILL\n        ),\n        \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.ones(\n            q_shape[0], rank\n        ),\n    }\n    path.mkdir(parents=True)\n    save_file(weights, str(path / \"adapter_model.safetensors\"))\n    (path / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": rank}))\n\n\ndef _load_all_tensors(output_dir: Path) -> dict[str, torch.Tensor]:\n    \"\"\"Load all safetensors from output directory.\"\"\"\n    single = output_dir / \"model.safetensors\"\n    if single.exists():\n        return load_file(str(single))\n    index_path = output_dir / \"model.safetensors.index.json\"\n    with open(index_path) as f:\n        weight_map = json.load(f)[\"weight_map\"]\n    tensors: dict[str, torch.Tensor] = {}\n    for shard_name in sorted(set(weight_map.values())):\n        tensors.update(load_file(str(output_dir / shard_name)))\n    return tensors\n\n\nclass TestStrategyConsistency:\n    \"\"\"Verify full and shard 
strategies produce identical merged weights.\"\"\"\n\n    def test_full_and_shard_produce_identical_weights(self):\n        config = _make_tiny_qwen3_dense_config()\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            model_path = root / \"model\"\n            adapter_path = root / \"adapter\"\n            output_full = root / \"merged_full\"\n            output_shard = root / \"merged_shard\"\n\n            _save_model_to_disk(config, model_path)\n            _save_adapter(adapter_path, model_path=model_path)\n\n            # Run both strategies\n            build_hf_model(\n                base_model=str(model_path),\n                adapter_path=str(adapter_path),\n                output_path=str(output_full),\n                merge_strategy=\"full\",\n            )\n            build_hf_model(\n                base_model=str(model_path),\n                adapter_path=str(adapter_path),\n                output_path=str(output_shard),\n                merge_strategy=\"shard\",\n            )\n\n            # Load both outputs\n            full_tensors = _load_all_tensors(output_full)\n            shard_tensors = _load_all_tensors(output_shard)\n\n            # Same keys\n            assert set(full_tensors.keys()) == set(shard_tensors.keys()), (\n                f\"Key mismatch: \"\n                f\"full_only={set(full_tensors.keys()) - set(shard_tensors.keys())}, \"\n                f\"shard_only={set(shard_tensors.keys()) - set(full_tensors.keys())}\"\n            )\n\n            # Same values (bit-identical)\n            mismatches = []\n            for key in sorted(full_tensors.keys()):\n                if not torch.equal(full_tensors[key], shard_tensors[key]):\n                    max_diff = (full_tensors[key].float() - shard_tensors[key].float()).abs().max()\n                    mismatches.append(f\"{key}: max_diff={max_diff:.6e}\")\n\n            assert not mismatches, (\n                \"Weight mismatches between full and shard strategies:\\n\" + \"\\n\".join(mismatches)\n            )\n\n            # Both should have config.json\n            assert (output_full / \"config.json\").exists()\n            assert (output_shard / \"config.json\").exists()\n\n            # Verify the merge actually changed something\n            orig = AutoModelForCausalLM.from_pretrained(\n                model_path, trust_remote_code=True, dtype=torch.bfloat16\n            )\n            orig_gate = orig.state_dict()[\"model.layers.0.mlp.gate_proj.weight\"]\n            merged_gate = full_tensors[\"model.layers.0.mlp.gate_proj.weight\"]\n            delta = (merged_gate.float() - orig_gate.float()).abs().sum()\n            assert delta > 0, \"Merge did not modify gate_proj\"\n"
  },
  {
    "path": "tinker_cookbook/__init__.py",
    "content": "\"\"\"Tinker Cookbook: post-training algorithms using the Tinker API.\"\"\"\n\ntry:\n    from tinker_cookbook._version import __version__\nexcept ImportError:\n    try:\n        from importlib.metadata import version\n\n        __version__ = version(\"tinker_cookbook\")\n    except Exception:\n        __version__ = \"0.0.0.dev0+unknown\"\n\nfrom tinker_cookbook.exceptions import (\n    AllTrajectoriesFailedError,\n    CheckpointError,\n    ConfigurationError,\n    DataError,\n    DataFormatError,\n    DataValidationError,\n    RendererError,\n    SandboxError,\n    TinkerCookbookError,\n    TrainingError,\n    WeightsDownloadError,\n    WeightsError,\n    WeightsMergeError,\n)\n\n__all__ = [\n    \"__version__\",\n    \"AllTrajectoriesFailedError\",\n    \"CheckpointError\",\n    \"ConfigurationError\",\n    \"DataError\",\n    \"DataFormatError\",\n    \"DataValidationError\",\n    \"RendererError\",\n    \"SandboxError\",\n    \"TinkerCookbookError\",\n    \"TrainingError\",\n    \"WeightsDownloadError\",\n    \"WeightsError\",\n    \"WeightsMergeError\",\n]\n"
  },
  {
    "path": "tinker_cookbook/chat_app/README.md",
    "content": "# Tinker Chat CLI\n\nThis README provides instructions for chatting with models trained using **Tinker**.\n\n---\n\n## Getting Started\n\nYou can easily chat with any sampler checkpoint saved using **Tinker** by running the following command:\n\n```bash\npython -m tinker_cookbook.chat_app.tinker_chat_cli \\\n    model_path=tinker://<unique_id>/sampler_weights/final \\\n    base_model=meta-llama/Llama-3.1-8B\n```\n\n### Arguments\n\n* **model_path**: Path to the trained Tinker sampler checkpoint. Example: `tinker://<unique_id>/sampler_weights/final`. Note that the Tinker chat CLI will not work with training weights which look like `tinker://<unique_id>/weights/final`. Make sure the checkpoint contains `sampler_weights`.\n* **base_model**: Hugging Face base model to use for inference. Example: `meta-llama/Llama-3.1-8B`\n\n---\n\n## Customization\n\nYou can modify the behavior of the chat by providing additional arguments:\n\n* **max_tokens** *(int, default=512)*\n  Maximum number of tokens to generate in the response.\n\n* **temperature** *(float, default=0.7)*\n  Controls the randomness of the output. Higher values = more random responses.\n\n* **top_p** *(float, default=0.9)*\n  Controls nucleus sampling. The model considers only the top tokens with cumulative probability `p`.\n\nExample:\n\n```bash\npython -m tinker_cookbook.chat_app.tinker_chat_cli \\\n    model_path=tinker://<unique_id>/sampler_weights/final \\\n    base_model=meta-llama/Llama-3.1-8B \\\n    max_tokens=256 \\\n    temperature=0.8 \\\n    top_p=0.95\n```\n"
  },
  {
    "path": "tinker_cookbook/chat_app/tinker_chat_cli.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nSimple CLI chat interface using tinker sampling client.\n\"\"\"\n\nimport asyncio\nimport logging\nimport os\nimport sys\n\nimport chz\nimport tinker\nfrom tinker import types\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.model_info import get_recommended_renderer_name\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nlogger = logging.getLogger(__name__)\nlogging.basicConfig(\n    format=\"%(asctime)s %(levelname)-8s %(filename)s:%(lineno)-4s %(message)s\",\n    level=logging.WARNING,\n    datefmt=\"%Y-%m-%d %H:%M:%S\",\n)\n\n\n@chz.chz\nclass Config:\n    base_model: str = \"meta-llama/Llama-3.2-1B\"\n    model_path: str | None = None\n    max_tokens: int = 512\n    temperature: float = 0.7\n    top_p: float = 0.9\n    base_url: str | None = None\n\n\nclass ChatSession:\n    \"\"\"Manages a chat session with conversation history.\"\"\"\n\n    def __init__(\n        self,\n        sampling_client: tinker.SamplingClient,\n        renderer: renderers.Renderer,\n        max_tokens: int,\n        temperature: float,\n        top_p: float,\n    ):\n        self.sampling_client: tinker.SamplingClient = sampling_client\n        self.renderer: renderers.Renderer = renderer\n        self.max_tokens: int = max_tokens\n        self.temperature: float = temperature\n        self.top_p: float = top_p\n        self.conversation_history: list[renderers.Message] = []\n\n    def add_user_message(self, content: str):\n        \"\"\"Add a user message to the conversation history.\"\"\"\n        self.conversation_history.append({\"role\": \"user\", \"content\": content})\n\n    def add_assistant_message(self, content: str):\n        \"\"\"Add an assistant message to the conversation history.\"\"\"\n        self.conversation_history.append({\"role\": \"assistant\", \"content\": content})\n\n    def clear_history(self):\n        \"\"\"Clear the conversation history.\"\"\"\n        self.conversation_history.clear()\n\n    async def generate_response(self) -> str:\n        \"\"\"Generate a response from the model.\"\"\"\n        try:\n            # Build the model input from conversation history\n            model_input = self.renderer.build_generation_prompt(self.conversation_history)\n\n            # Set up sampling parameters\n            sampling_params = types.SamplingParams(\n                max_tokens=self.max_tokens,\n                temperature=self.temperature,\n                top_p=self.top_p,\n                stop=self.renderer.get_stop_sequences(),\n            )\n\n            # Generate response\n            response = await self.sampling_client.sample_async(\n                prompt=model_input, num_samples=1, sampling_params=sampling_params\n            )\n\n            # Parse the response\n            parsed_message, _ = self.renderer.parse_response(response.sequences[0].tokens)\n            content = parsed_message[\"content\"]\n\n            # Format content for display\n            display_content = renderers.format_content_as_string(content, separator=\"\\n--------\\n\")\n\n            self.add_assistant_message(display_content)\n            return display_content\n\n        except Exception as e:\n            logger.error(f\"Error generating response: {e}\")\n            return f\"Error: {e}\"\n\n\nasync def main(config: Config):\n    \"\"\"Main chat loop.\"\"\"\n\n    print(f\"🚀 Initializing chat with model: {config.base_model}\")\n    print(f\"📦 Using Path: {config.model_path}\")\n\n    
os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n    try:\n        # Create service client\n        service_client = tinker.ServiceClient(base_url=config.base_url)\n\n        # Create sampling client\n        sampling_client = service_client.create_sampling_client(\n            base_model=config.base_model,\n            model_path=config.model_path if config.model_path else None,\n        )\n\n        # Get tokenizer and renderer\n        tokenizer = get_tokenizer(config.base_model)\n        renderer = renderers.get_renderer(\n            get_recommended_renderer_name(config.base_model), tokenizer\n        )\n\n        # Create chat session\n        chat_session = ChatSession(\n            sampling_client=sampling_client,\n            renderer=renderer,\n            max_tokens=config.max_tokens,\n            temperature=config.temperature,\n            top_p=config.top_p,\n        )\n\n        print(\"\\n💬 Chat started! Type 'quit', 'exit', or Ctrl+C to end the session.\")\n        print(\"🗑️  Type 'n' to clear conversation history and start a new conversation.\")\n        print(\"🤖 You can start chatting now...\\n\")\n\n        # Main chat loop\n        while True:\n            try:\n                # Get user input\n                user_input = input(\"User: \").strip()\n\n                # Check for exit commands\n                if user_input.lower() in [\"quit\", \"exit\", \"q\"]:\n                    print(\"👋 Goodbye!\")\n                    break\n\n                # Check for clear history command\n                if user_input.lower() == \"n\":\n                    chat_session.clear_history()\n                    print(\"🗑️  Conversation history cleared! Starting a new conversation.\")\n                    continue\n\n                if not user_input:\n                    continue\n\n                # Add user message to conversation\n                chat_session.add_user_message(user_input)\n\n                # Generate and display response\n                print(\"Assistant: \", end=\"\", flush=True)\n                response = await chat_session.generate_response()\n                print(response)\n                print()  # Empty line for readability\n\n            except KeyboardInterrupt:\n                print(\"\\n👋 Chat interrupted. Goodbye!\")\n                break\n            except EOFError:\n                print(\"\\n👋 Chat ended. Goodbye!\")\n                break\n            except Exception as e:\n                print(f\"❌ Error: {e}\")\n                logger.exception(\"Unexpected error in chat loop\")\n\n    except Exception as e:\n        print(f\"❌ Failed to initialize chat: {e}\")\n        logger.exception(\"Failed to initialize chat\")\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    asyncio.run(chz.nested_entrypoint(main))\n"
  },
  {
    "path": "tinker_cookbook/checkpoint_utils.py",
    "content": "import asyncio\nimport dataclasses\nimport json\nimport logging\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any, Literal\n\nimport tinker\n\nfrom tinker_cookbook import model_info\nfrom tinker_cookbook.utils import trace\nfrom tinker_cookbook.utils.file_utils import read_jsonl\n\nCHECKPOINTS_BASE_NAME = \"checkpoints.jsonl\"\n\nlogger = logging.getLogger(__name__)\nRENDERER_NAME_METADATA_KEY = \"renderer_name\"\n\n\n_MISSING = object()  # sentinel for distinguishing \"not set\" from None\n\n\n@dataclass\nclass CheckpointRecord:\n    \"\"\"A single checkpoint record stored in ``checkpoints.jsonl``.\n\n    Known fields are exposed as typed attributes.  ``batch`` is optional so\n    that checkpoint files written by older code (or external tools that use\n    different progress counters) can still be loaded.\n\n    Any additional user-supplied metadata from ``loop_state`` is preserved in\n    :attr:`extra` so that custom keys round-trip through save/load without\n    loss.\n    \"\"\"\n\n    name: str\n    batch: int | None = None\n    epoch: int | None = None\n    final: bool | None = None\n    state_path: str | None = None\n    sampler_path: str | None = None\n    extra: dict[str, Any] = field(default_factory=dict)\n\n    def __post_init__(self) -> None:\n        # Defensive: if extra accidentally contains a known key (e.g. via\n        # direct construction), drop it so to_dict() never double-writes.\n        overlap = set(self.extra) & _CHECKPOINT_RECORD_KNOWN_KEYS\n        if overlap:\n            logger.warning(\"CheckpointRecord: dropping known keys from extra: %s\", overlap)\n            self.extra = {k: v for k, v in self.extra.items() if k not in overlap}\n\n    def to_dict(self) -> dict[str, Any]:\n        \"\"\"Serialize to a dict for JSON storage. Omits ``None`` optional fields.\"\"\"\n        d: dict[str, Any] = {\"name\": self.name}\n        if self.batch is not None:\n            d[\"batch\"] = self.batch\n        if self.epoch is not None:\n            d[\"epoch\"] = self.epoch\n        if self.final is not None:\n            d[\"final\"] = self.final\n        if self.state_path is not None:\n            d[\"state_path\"] = self.state_path\n        if self.sampler_path is not None:\n            d[\"sampler_path\"] = self.sampler_path\n        d.update(self.extra)\n        return d\n\n    @classmethod\n    def from_dict(cls, d: dict[str, Any]) -> \"CheckpointRecord\":\n        \"\"\"Deserialize from a JSON-parsed dict.\n\n        Unknown keys are preserved in :attr:`extra` so that downstream\n        metadata (e.g. 
``step``) round-trips without loss.\n        \"\"\"\n        return cls(\n            name=d[\"name\"],\n            batch=d.get(\"batch\"),\n            epoch=d.get(\"epoch\"),\n            final=d.get(\"final\"),\n            state_path=d.get(\"state_path\"),\n            sampler_path=d.get(\"sampler_path\"),\n            extra={k: v for k, v in d.items() if k not in _CHECKPOINT_RECORD_KNOWN_KEYS},\n        )\n\n    def has(self, key: str) -> bool:\n        \"\"\"Check whether a field is present (not None), including extra keys.\"\"\"\n        if key in _CHECKPOINT_RECORD_KNOWN_KEYS:\n            return getattr(self, key) is not None\n        return key in self.extra\n\n    def get(self, key: str, default: Any = _MISSING) -> Any:\n        \"\"\"Get a field value by name, falling back to extra, then *default*.\n\n        This provides uniform access regardless of whether a key is a known\n        attribute or user-supplied metadata stored in :attr:`extra`.\n\n        For known fields, returns the attribute value (which may be ``None``\n        if the field is optional and unset).  Returns *default* only when the\n        key is not a known field **and** is absent from :attr:`extra`.\n        \"\"\"\n        if key in _CHECKPOINT_RECORD_KNOWN_KEYS:\n            return getattr(self, key)\n        if default is _MISSING:\n            return self.extra.get(key)\n        return self.extra.get(key, default)\n\n\n# Derived from the dataclass fields so it stays in sync automatically.\n# Excludes \"extra\" since that's the catch-all, not a serialized key.\n_CHECKPOINT_RECORD_KNOWN_KEYS = frozenset(\n    f.name for f in dataclasses.fields(CheckpointRecord) if f.name != \"extra\"\n)\n\n\ndef add_renderer_name_to_user_metadata(\n    user_metadata: dict[str, str], renderer_name: str | None\n) -> None:\n    \"\"\"Attach renderer name to training-run metadata when available.\"\"\"\n    if renderer_name:\n        user_metadata[RENDERER_NAME_METADATA_KEY] = renderer_name\n\n\ndef _handle_checkpoint_renderer_check_result(\n    checkpoint_path: str,\n    expected_renderer_name: str,\n    checkpoint_renderer_name: str | None,\n) -> None:\n    if checkpoint_renderer_name is None:\n        logger.info(\"Checkpoint %s has no renderer metadata.\", checkpoint_path)\n    elif checkpoint_renderer_name != expected_renderer_name:\n        logger.warning(\n            \"Renderer mismatch for checkpoint %s: checkpoint=%s current=%s\",\n            checkpoint_path,\n            checkpoint_renderer_name,\n            expected_renderer_name,\n        )\n    else:\n        logger.info(\n            \"Renderer metadata matches for checkpoint %s: %s\",\n            checkpoint_path,\n            expected_renderer_name,\n        )\n    return None\n\n\ndef get_renderer_name_from_checkpoint(\n    service_client: tinker.ServiceClient, checkpoint_path: str\n) -> str | None:\n    \"\"\"Read renderer_name metadata from the training run referenced by a checkpoint path.\"\"\"\n    try:\n        rest_client = service_client.create_rest_client()\n        training_run = rest_client.get_training_run_by_tinker_path(checkpoint_path).result()\n        return (training_run.user_metadata or {}).get(RENDERER_NAME_METADATA_KEY)\n    except (tinker.TinkerError, ValueError) as e:\n        logger.warning(\n            \"Could not fetch renderer metadata for checkpoint %s: %s\",\n            checkpoint_path,\n            e,\n        )\n        return None\n\n\nasync def get_renderer_name_from_checkpoint_async(\n    service_client: 
tinker.ServiceClient, checkpoint_path: str\n) -> str | None:\n    \"\"\"Async version of get_renderer_name_from_checkpoint.\"\"\"\n    try:\n        rest_client = service_client.create_rest_client()\n        training_run = await rest_client.get_training_run_by_tinker_path_async(checkpoint_path)\n        return (training_run.user_metadata or {}).get(RENDERER_NAME_METADATA_KEY)\n    except (tinker.TinkerError, ValueError) as e:\n        logger.warning(\n            \"Could not fetch renderer metadata for checkpoint %s: %s\",\n            checkpoint_path,\n            e,\n        )\n        return None\n\n\ndef resolve_renderer_name_from_checkpoint_or_default(\n    *,\n    model_name: str,\n    explicit_renderer_name: str | None,\n    load_checkpoint_path: str | None,\n    base_url: str | None = None,\n) -> str:\n    \"\"\"\n    Resolve renderer name for training/eval setup.\n\n    Precedence:\n    1) explicit renderer name, if provided\n    2) renderer metadata from load checkpoint path, if available\n    3) recommended renderer for model_name\n    \"\"\"\n    if explicit_renderer_name is not None:\n        return explicit_renderer_name\n\n    if load_checkpoint_path is not None:\n        service_client = tinker.ServiceClient(base_url=base_url)\n        renderer_name = get_renderer_name_from_checkpoint(service_client, load_checkpoint_path)\n        if renderer_name is not None:\n            logger.info(\n                \"Using renderer from checkpoint metadata for %s: %s\",\n                load_checkpoint_path,\n                renderer_name,\n            )\n            return renderer_name\n\n    return model_info.get_recommended_renderer_name(model_name)\n\n\nasync def resolve_renderer_name_from_checkpoint_or_default_async(\n    *,\n    model_name: str,\n    explicit_renderer_name: str | None,\n    load_checkpoint_path: str | None,\n    base_url: str | None = None,\n) -> str:\n    \"\"\"\n    Async version of resolve_renderer_name_from_checkpoint_or_default.\n    \"\"\"\n    if explicit_renderer_name is not None:\n        return explicit_renderer_name\n\n    if load_checkpoint_path is not None:\n        service_client = tinker.ServiceClient(base_url=base_url)\n        renderer_name = await get_renderer_name_from_checkpoint_async(\n            service_client, load_checkpoint_path\n        )\n        if renderer_name is not None:\n            logger.info(\n                \"Using renderer from checkpoint metadata for %s: %s\",\n                load_checkpoint_path,\n                renderer_name,\n            )\n            return renderer_name\n\n    return model_info.get_recommended_renderer_name(model_name)\n\n\ndef check_renderer_name_for_checkpoint(\n    service_client: tinker.ServiceClient,\n    checkpoint_path: str,\n    expected_renderer_name: str | None,\n) -> None:\n    \"\"\"\n    Inspect a checkpoint's originating training run metadata and compare renderer name.\n\n    \"\"\"\n    if expected_renderer_name is None:\n        return None\n\n    checkpoint_renderer_name = get_renderer_name_from_checkpoint(service_client, checkpoint_path)\n\n    _handle_checkpoint_renderer_check_result(\n        checkpoint_path, expected_renderer_name, checkpoint_renderer_name\n    )\n    return None\n\n\nasync def check_renderer_name_for_checkpoint_async(\n    service_client: tinker.ServiceClient,\n    checkpoint_path: str,\n    expected_renderer_name: str | None,\n) -> None:\n    \"\"\"\n    Compare an expected renderer with renderer metadata attached to a checkpoint's training run.\n\n    
Behavior:\n    - If ``expected_renderer_name`` is None, returns None and does no check.\n    - Otherwise fetches ``renderer_name`` from the run referenced by ``checkpoint_path``.\n    - Logs info if metadata is missing or matches.\n    - Logs warning if the checkpoint renderer differs from the expected renderer.\n\n    \"\"\"\n    if expected_renderer_name is None:\n        return None\n\n    checkpoint_renderer_name = await get_renderer_name_from_checkpoint_async(\n        service_client, checkpoint_path\n    )\n\n    _handle_checkpoint_renderer_check_result(\n        checkpoint_path, expected_renderer_name, checkpoint_renderer_name\n    )\n    return None\n\n\n@trace.scope\ndef load_checkpoints_file(log_dir: str) -> list[CheckpointRecord]:\n    checkpoint_path = Path(log_dir) / CHECKPOINTS_BASE_NAME\n    if not checkpoint_path.exists():\n        logger.info(f\"No checkpoints found at {checkpoint_path}\")\n        return []\n\n    logger.info(f\"Reading checkpoints from {checkpoint_path}\")\n    trace.update_scope_context({\"checkpoint_path\": str(checkpoint_path)})\n    return [CheckpointRecord.from_dict(d) for d in read_jsonl(str(checkpoint_path))]\n\n\n@trace.scope\ndef get_last_checkpoint(log_dir: str, required_key: str = \"state_path\") -> CheckpointRecord | None:\n    \"\"\"\n    Get the last checkpoint from the checkpoints.jsonl file in the specified log directory.\n\n    Args:\n        log_dir: The directory to check.\n        required_key: The key to check for in the checkpoint.\n            We might save partial checkpoints (e.g. sampler) in the same file,\n            so we need to filter to the rows that have a fully-resumable checkpoint.\n\n    Returns:\n        The last checkpoint, or None if no checkpoint is found.\n    \"\"\"\n    checkpoints = load_checkpoints_file(log_dir)\n    checkpoints_with_key = [c for c in checkpoints if c.has(required_key)]\n    if checkpoints_with_key:\n        logger.info(\n            f\"Found {len(checkpoints_with_key)} valid checkpoints with key '{required_key}' in {log_dir}\"\n        )\n        logger.info(f\"Using last checkpoint: {checkpoints_with_key[-1]}\")\n        return checkpoints_with_key[-1]\n    else:\n        logger.info(f\"No checkpoints found with key {required_key} in {log_dir}\")\n        return None\n\n\n@trace.scope\nasync def save_checkpoint_async(\n    training_client: tinker.TrainingClient,\n    name: str,\n    log_path: str,\n    loop_state: dict[str, Any],\n    kind: Literal[\"state\", \"sampler\", \"both\"] = \"state\",\n    ttl_seconds: int | None = None,\n) -> dict[str, str]:\n    \"\"\"Save model checkpoint and append a record to ``checkpoints.jsonl``.\n\n    Args:\n        training_client: Training client to save from.\n        name: Name for the checkpoint (used in the tinker:// path).\n        log_path: Directory containing ``checkpoints.jsonl``.\n        loop_state: Training loop state. May include ``batch``, ``step``,\n            ``epoch``, ``final``, and any additional user metadata.\n        kind: Which checkpoint types to save.\n        ttl_seconds: Server-side retention. 
``None`` keeps the checkpoint indefinitely.\n\n    Returns:\n        Dict mapping ``\"state_path\"`` and/or ``\"sampler_path\"`` to tinker:// paths.\n    \"\"\"\n    futures = {}\n    if kind in [\"state\", \"both\"]:\n        futures[\"state\"] = await training_client.save_state_async(name, ttl_seconds=ttl_seconds)\n    if kind in [\"sampler\", \"both\"]:\n        futures[\"sampler\"] = await training_client.save_weights_for_sampler_async(\n            name, ttl_seconds=ttl_seconds\n        )\n\n    results = {k: await v.result_async() for k, v in futures.items()}\n    paths = {k + \"_path\": v.path for k, v in results.items()}\n    trace.update_scope_context(paths)\n    logger.info(f\"Saved checkpoints: {paths}\")\n\n    record = CheckpointRecord.from_dict({\"name\": name, **loop_state, **paths})\n    with open(Path(log_path) / \"checkpoints.jsonl\", \"a\") as f:\n        f.write(json.dumps(record.to_dict()) + \"\\n\")\n\n    return paths\n\n\n@trace.scope\ndef save_checkpoint(\n    training_client: tinker.TrainingClient,\n    name: str,\n    log_path: str,\n    loop_state: dict[str, Any],\n    kind: Literal[\"state\", \"sampler\", \"both\"] = \"state\",\n    ttl_seconds: int | None = None,\n) -> dict[str, str]:\n    \"\"\"Synchronous wrapper around :func:`save_checkpoint_async`.\n\n    Args:\n        training_client: Training client to save from.\n        name: Name for the checkpoint (used in the tinker:// path).\n        log_path: Directory containing ``checkpoints.jsonl``.\n        loop_state: Training loop state and user metadata to record.\n        kind: Which checkpoint types to save.\n        ttl_seconds: Server-side retention. ``None`` keeps the checkpoint indefinitely.\n\n    Returns:\n        Dict mapping ``\"state_path\"`` and/or ``\"sampler_path\"`` to tinker:// paths.\n    \"\"\"\n    return asyncio.run(\n        save_checkpoint_async(\n            training_client,\n            name=name,\n            log_path=log_path,\n            kind=kind,\n            loop_state=loop_state,\n            ttl_seconds=ttl_seconds,\n        )\n    )\n"
  },
  {
    "path": "tinker_cookbook/checkpoint_utils_test.py",
    "content": "\"\"\"Tests for checkpoint_utils path handling.\"\"\"\n\nimport json\nimport tempfile\nfrom pathlib import Path\n\nfrom tinker_cookbook.checkpoint_utils import (\n    CheckpointRecord,\n    get_last_checkpoint,\n    load_checkpoints_file,\n)\n\n\ndef _write_checkpoints_jsonl(log_dir: str, records: list[dict]) -> None:\n    path = Path(log_dir) / \"checkpoints.jsonl\"\n    with open(path, \"w\") as f:\n        for record in records:\n            f.write(json.dumps(record) + \"\\n\")\n\n\ndef test_load_checkpoints_file_missing_dir():\n    \"\"\"load_checkpoints_file returns [] when the directory doesn't exist.\"\"\"\n    result = load_checkpoints_file(\"/tmp/nonexistent_dir_abc123\")\n    assert result == []\n\n\ndef test_load_checkpoints_file_missing_file():\n    \"\"\"load_checkpoints_file returns [] when checkpoints.jsonl is absent.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        result = load_checkpoints_file(tmpdir)\n        assert result == []\n\n\ndef test_load_checkpoints_file_reads_records():\n    \"\"\"load_checkpoints_file reads and deserializes checkpoint records.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        _write_checkpoints_jsonl(\n            tmpdir,\n            [\n                {\"name\": \"000005\", \"batch\": 5, \"state_path\": \"tinker://state/5\"},\n                {\"name\": \"000010\", \"batch\": 10, \"state_path\": \"tinker://state/10\"},\n            ],\n        )\n        result = load_checkpoints_file(tmpdir)\n        assert len(result) == 2\n        assert isinstance(result[0], CheckpointRecord)\n        assert result[0].name == \"000005\"\n        assert result[1].batch == 10\n\n\ndef test_get_last_checkpoint_returns_last():\n    \"\"\"get_last_checkpoint returns the last record with the required key.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        _write_checkpoints_jsonl(\n            tmpdir,\n            [\n                {\"name\": \"000005\", \"batch\": 5, \"state_path\": \"tinker://state/5\"},\n                {\"name\": \"000010\", \"batch\": 10, \"sampler_path\": \"tinker://sampler/10\"},\n                {\"name\": \"000015\", \"batch\": 15, \"state_path\": \"tinker://state/15\"},\n            ],\n        )\n        result = get_last_checkpoint(tmpdir, required_key=\"state_path\")\n        assert result is not None\n        assert result.name == \"000015\"\n\n\ndef test_get_last_checkpoint_returns_none_when_empty():\n    \"\"\"get_last_checkpoint returns None when no checkpoints exist.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        result = get_last_checkpoint(tmpdir)\n        assert result is None\n\n\ndef test_get_last_checkpoint_returns_none_when_key_missing():\n    \"\"\"get_last_checkpoint returns None when no record has the required key.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        _write_checkpoints_jsonl(\n            tmpdir,\n            [{\"name\": \"000005\", \"batch\": 5, \"sampler_path\": \"tinker://sampler/5\"}],\n        )\n        result = get_last_checkpoint(tmpdir, required_key=\"state_path\")\n        assert result is None\n\n\ndef test_load_checkpoints_file_without_batch():\n    \"\"\"Entries without 'batch' should deserialize without error (backward compat).\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        _write_checkpoints_jsonl(\n            tmpdir,\n            [\n                {\"name\": \"000000\", \"step\": 0},\n                {\"name\": \"000010\", \"step\": 10, \"state_path\": 
\"tinker://state/10\"},\n            ],\n        )\n        result = load_checkpoints_file(tmpdir)\n        assert len(result) == 2\n        assert result[0].batch is None\n        assert result[0].extra[\"step\"] == 0\n        assert result[1].state_path == \"tinker://state/10\"\n\n\ndef test_checkpoint_record_extra_round_trips():\n    \"\"\"Unknown keys land in extra and survive to_dict/from_dict round-trip.\"\"\"\n    record = CheckpointRecord.from_dict(\n        {\"name\": \"000005\", \"batch\": 5, \"step\": 5, \"custom_key\": \"val\"}\n    )\n    assert record.extra == {\"step\": 5, \"custom_key\": \"val\"}\n    d = record.to_dict()\n    assert d[\"step\"] == 5\n    assert d[\"custom_key\"] == \"val\"\n    restored = CheckpointRecord.from_dict(d)\n    assert restored.extra == {\"step\": 5, \"custom_key\": \"val\"}\n\n\ndef test_checkpoint_record_name_only():\n    \"\"\"A minimal entry with only 'name' should deserialize (batch None).\"\"\"\n    record = CheckpointRecord.from_dict({\"name\": \"000000\"})\n    assert record.name == \"000000\"\n    assert record.batch is None\n\n\ndef test_checkpoint_record_get_known_field():\n    \"\"\"get() returns known field values, including None for unset optional fields.\"\"\"\n    record = CheckpointRecord(name=\"test\", batch=5, state_path=\"tinker://state/5\")\n    assert record.get(\"batch\") == 5\n    assert record.get(\"state_path\") == \"tinker://state/5\"\n    # Known fields always return the attribute value, even when None.\n    # This distinguishes \"field exists but is unset\" from \"key is unknown\".\n    assert record.get(\"epoch\") is None\n    assert record.get(\"epoch\", -1) is None\n\n\ndef test_checkpoint_record_get_extra_field():\n    \"\"\"get() falls through to extra for unknown keys.\"\"\"\n    record = CheckpointRecord(name=\"test\", extra={\"step\": 10, \"custom\": \"val\"})\n    assert record.get(\"step\") == 10\n    assert record.get(\"custom\") == \"val\"\n    assert record.get(\"missing\") is None\n    assert record.get(\"missing\", \"default\") == \"default\"\n\n\ndef test_checkpoint_record_has_extra_field():\n    \"\"\"has() works for both known fields and extra keys.\"\"\"\n    record = CheckpointRecord(name=\"test\", batch=5, extra={\"step\": 10})\n    assert record.has(\"batch\")\n    assert not record.has(\"epoch\")\n    assert record.has(\"step\")\n    assert not record.has(\"missing\")\n\n\ndef test_checkpoint_record_extra_overlap_with_known_keys():\n    \"\"\"Known keys in extra are dropped defensively to prevent to_dict() conflicts.\"\"\"\n    record = CheckpointRecord(name=\"test\", batch=5, extra={\"batch\": 99, \"custom\": \"val\"})\n    # \"batch\" should be stripped from extra; the attribute value (5) wins\n    assert record.batch == 5\n    assert \"batch\" not in record.extra\n    assert record.extra == {\"custom\": \"val\"}\n    # to_dict() should have batch=5, not 99\n    d = record.to_dict()\n    assert d[\"batch\"] == 5\n"
  },
  {
    "path": "tinker_cookbook/cli_utils.py",
    "content": "import logging\nimport shutil\nfrom pathlib import Path\nfrom typing import Literal\n\nfrom tinker_cookbook.exceptions import ConfigurationError\n\nlogger = logging.getLogger(__name__)\n\nLogdirBehavior = Literal[\"delete\", \"resume\", \"ask\", \"raise\"]\n\n\ndef check_log_dir(log_dir: str, behavior_if_exists: LogdirBehavior):\n    \"\"\"\n    Call this at the beginning of CLI entrypoint to training scripts. This handles\n    cases that occur if we're trying to log to a directory that already exists.\n    The user might want to resume, overwrite, or delete it.\n\n    Args:\n        log_dir: The directory to check.\n        behavior_if_exists: What to do if the log directory already exists.\n\n        \"ask\": Ask user if they want to delete the log directory.\n        \"resume\": Continue to the training loop, which means we'll try to resume from the last checkpoint.\n        \"delete\": Delete the log directory and start logging there.\n        \"raise\": Raise an error if the log directory already exists.\n\n    Returns:\n        None\n    \"\"\"\n    if Path(log_dir).exists():\n        if behavior_if_exists == \"delete\":\n            logger.info(\n                f\"Log directory {log_dir} already exists. Will delete it and start logging there.\"\n            )\n            shutil.rmtree(log_dir)\n        elif behavior_if_exists == \"ask\":\n            while True:\n                user_input = input(\n                    f\"Log directory {log_dir} already exists. What do you want to do? [delete, resume, exit]: \"\n                )\n                if user_input == \"delete\":\n                    shutil.rmtree(log_dir)\n                    return\n                elif user_input == \"resume\":\n                    return\n                elif user_input == \"exit\":\n                    exit(0)\n                else:\n                    logger.warning(\n                        f\"Invalid input: {user_input}. Please enter 'delete', 'resume', or 'exit'.\"\n                    )\n        elif behavior_if_exists == \"resume\":\n            return\n        elif behavior_if_exists == \"raise\":\n            raise ConfigurationError(f\"Log directory {log_dir} already exists. Will not delete it.\")\n        else:\n            raise AssertionError(f\"Invalid behavior_if_exists: {behavior_if_exists}\")\n    else:\n        logger.info(\n            f\"Log directory {log_dir} does not exist. Will create it and start logging there.\"\n        )\n"
  },
  {
    "path": "tinker_cookbook/cli_utils_test.py",
    "content": "\"\"\"Tests for cli_utils path handling.\"\"\"\n\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom tinker_cookbook.cli_utils import check_log_dir\n\n\ndef test_check_log_dir_nonexistent_is_noop():\n    \"\"\"check_log_dir does nothing when the directory doesn't exist.\"\"\"\n    check_log_dir(\"/tmp/nonexistent_dir_abc123\", \"raise\")\n\n\ndef test_check_log_dir_resume_keeps_directory():\n    \"\"\"check_log_dir with 'resume' leaves the directory intact.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        marker = Path(tmpdir) / \"keep_me.txt\"\n        marker.write_text(\"hello\")\n        check_log_dir(tmpdir, \"resume\")\n        assert marker.exists()\n\n\ndef test_check_log_dir_delete_removes_directory():\n    \"\"\"check_log_dir with 'delete' removes the directory.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        target = Path(tmpdir) / \"subdir\"\n        target.mkdir()\n        (target / \"file.txt\").write_text(\"hello\")\n        check_log_dir(str(target), \"delete\")\n        assert not target.exists()\n\n\ndef test_check_log_dir_raise_raises():\n    \"\"\"check_log_dir with 'raise' raises ValueError when directory exists.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        with pytest.raises(ValueError, match=\"already exists\"):\n            check_log_dir(tmpdir, \"raise\")\n"
  },
  {
    "path": "tinker_cookbook/completers.py",
    "content": "\"\"\"\nImplementations that correspond to a model or policy that can be sampled from, but with different amounts of additional structure.\n\nThe TokenCompleter operates on tokens. This is the version used by RL algorithms, because RL algorithms work on Tokens. The MessageCompleter operates on messages, so it needs to be used with a renderer.\n\nEvals and other code should use the appropriate interface.\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import TypeAlias\n\nimport tinker\n\nfrom tinker_cookbook import renderers\n\n# Interfaces\n\nStopCondition: TypeAlias = list[str] | list[int]\n\n\n@dataclass\nclass TokensWithLogprobs:\n    tokens: list[int]\n    maybe_logprobs: list[float] | None\n\n    @property\n    def logprobs(self) -> list[float]:\n        if self.maybe_logprobs is None:\n            raise ValueError(\"Logprobs are not available\")\n        return self.maybe_logprobs\n\n\nclass TokenCompleter:\n    async def __call__(\n        self, model_input: tinker.ModelInput, stop: StopCondition\n    ) -> TokensWithLogprobs:\n        raise NotImplementedError\n\n\nclass MessageCompleter:\n    # TODO maybe add n_samples to the interfaces?\n    async def __call__(self, messages: list[renderers.Message]) -> renderers.Message:\n        raise NotImplementedError\n\n\n# Implementations\n\n\n@dataclass\nclass TinkerTokenCompleter(TokenCompleter):\n    \"\"\"\n    The most standard TokenCompleter, which uses a tinker.SamplingClient to sample actions.\n    \"\"\"\n\n    sampling_client: tinker.SamplingClient\n    max_tokens: int\n    temperature: float = 1.0\n\n    async def __call__(\n        self, model_input: tinker.ModelInput, stop: StopCondition\n    ) -> TokensWithLogprobs:\n        \"\"\"Sample an action from the policy given an observation.\"\"\"\n        # Sample from the model\n        sample_result = await self.sampling_client.sample_async(\n            prompt=model_input,\n            num_samples=1,\n            sampling_params=tinker.SamplingParams(\n                stop=stop,\n                max_tokens=self.max_tokens,\n                temperature=self.temperature,\n            ),\n        )\n\n        # Extract tokens and logprobs from the first (and only) sample\n        sampled_tokens = sample_result.sequences[0].tokens\n        sampled_logprobs = sample_result.sequences[0].logprobs\n        assert sampled_logprobs is not None\n\n        return TokensWithLogprobs(tokens=sampled_tokens, maybe_logprobs=sampled_logprobs)\n\n\nclass TinkerMessageCompleter(MessageCompleter):\n    \"\"\"A completer that uses the actual model to generate responses.\"\"\"\n\n    def __init__(\n        self,\n        sampling_client: tinker.SamplingClient,\n        renderer: renderers.Renderer,\n        max_tokens: int,\n        stop_condition: StopCondition | None = None,\n        temperature: float = 1.0,\n    ):\n        self.sampling_client = sampling_client\n        self.renderer = renderer\n        self.max_tokens = max_tokens\n        self.temperature = temperature\n        if stop_condition is None:\n            self.stop_condition = self.renderer.get_stop_sequences()\n        else:\n            self.stop_condition = stop_condition\n\n    async def __call__(self, messages: list[renderers.Message]) -> renderers.Message:\n        # Render the conversation for the model\n        model_input = self.renderer.build_generation_prompt(messages)\n\n        # Sample from the model\n        response = await self.sampling_client.sample_async(\n            model_input,\n        
    num_samples=1,\n            sampling_params=tinker.SamplingParams(\n                temperature=self.temperature,\n                max_tokens=self.max_tokens,\n                stop=self.stop_condition,\n            ),\n        )\n\n        # Decode the response\n        parsed_message, _success = self.renderer.parse_response(response.sequences[0].tokens)\n\n        result: renderers.Message = {\"role\": \"assistant\", \"content\": parsed_message[\"content\"]}\n        if \"tool_calls\" in parsed_message:\n            result[\"tool_calls\"] = parsed_message[\"tool_calls\"]\n        return result\n"
  },
  {
    "path": "tinker_cookbook/display.py",
    "content": "import io\n\nimport tinker\nfrom termcolor import colored\n\nfrom tinker_cookbook.rl.types import Trajectory, Transition\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\nfrom tinker_cookbook.utils.format_colorized import format_colorized\n\n\ndef to_ints(chunk: tinker.ModelInputChunk, tokenizer: Tokenizer):\n    if isinstance(chunk, tinker.EncodedTextChunk):\n        return chunk.tokens\n    else:\n        (at_token,) = tokenizer.encode(\"@\", add_special_tokens=False)\n        return [at_token] * chunk.length\n\n\ndef colorize_example(datum: tinker.Datum, tokenizer: Tokenizer, key: str = \"weights\"):\n    int_tokens = [\n        token for chunk in datum.model_input.chunks for token in to_ints(chunk, tokenizer)\n    ] + [datum.loss_fn_inputs[\"target_tokens\"].tolist()[-1]]\n    weights = [0.0] + datum.loss_fn_inputs[key].tolist()\n    return format_colorized(int_tokens, weights, tokenizer)\n\n\ndef format_trajectory(\n    trajectory: Trajectory, tokenizer: Tokenizer, only_last_transition: bool = False\n) -> str:\n    buf = io.StringIO()\n\n    def colorize(s: str):\n        return colored(s, \"green\", attrs=[\"bold\"])\n\n    def bprint(s: str):\n        print(s, file=buf)\n\n    bprint(\"=\" * 60)\n    transitions: list[tuple[int, Transition]] = list(enumerate(trajectory.transitions))\n    if only_last_transition:\n        transitions = transitions[-1:]\n    for i, transition in transitions:\n        bprint(f\"------ Transition {i} ------\")\n        bprint(f\"{colorize('Observation:')}: {tokenizer.decode(transition.ob.to_ints())}\")\n        bprint(f\"{colorize('Action:')}: {tokenizer.decode(transition.ac.tokens)}\")\n        bprint(f\"{colorize('Reward:')}: {transition.reward}\")\n        bprint(f\"{colorize('Episode done:')}: {transition.episode_done}\")\n        bprint(f\"{colorize('Metrics:')}: {transition.metrics}\")\n        bprint(\"-\" * 60)\n    bprint(\"=\" * 60)\n    return buf.getvalue()\n"
  },
  {
    "path": "tinker_cookbook/distillation/__init__.py",
    "content": ""
  },
  {
    "path": "tinker_cookbook/distillation/datasets.py",
    "content": "\"\"\"\nDataset utilities for on-policy distillation.\n\nThis module contains dataset configuration classes and environment definitions\nfor distillation where the only supervision comes from the KL penalty against\na teacher model. The environment provides no correctness or format rewards.\n\"\"\"\n\nimport math\nfrom collections.abc import Sequence\nfrom functools import partial\nfrom typing import Literal\n\nimport chz\nimport tinker\nfrom datasets import load_dataset\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.exceptions import ConfigurationError, DataError\nfrom tinker_cookbook.rl.problem_env import ProblemEnv, ProblemGroupBuilder, logger\nfrom tinker_cookbook.rl.types import (\n    Action,\n    EnvGroupBuilder,\n    RLDataset,\n    RLDatasetBuilder,\n    StepResult,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\n@chz.chz\nclass TeacherConfig:\n    \"\"\"Configuration for a teacher model.\"\"\"\n\n    base_model: str\n    load_checkpoint_path: str | None = None\n\n\n@chz.chz\nclass DistillationDatasetConfig:\n    \"\"\"Configuration for a dataset used in distillation.\"\"\"\n\n    dataset_builder: RLDatasetBuilder\n    teacher_config: TeacherConfig\n    groups_per_batch: int\n\n\nclass CompositeDataset:\n    \"\"\"Wraps multiple datasets and samples from each according to their groups_per_batch.\"\"\"\n\n    def __init__(self, datasets: list[RLDataset], groups_per_batch_list: list[int]):\n        self.datasets = datasets\n        self.groups_per_batch_list = groups_per_batch_list\n        # Use the shortest dataset length\n        if len(datasets) > 0:\n            self.length = min(len(dataset) for dataset in datasets)\n        else:\n            self.length = 0\n\n    def __len__(self) -> int:\n        return self.length\n\n    def get_batch(self, i_batch: int) -> tuple[list[EnvGroupBuilder], list[int]]:\n        \"\"\"\n        Get a batch by sampling from each dataset according to groups_per_batch.\n\n        Returns:\n            env_group_builders: List of all env group builders\n            dataset_indices: List of dataset indices corresponding to each env group builder\n        \"\"\"\n        all_env_group_builders = []\n        all_dataset_indices = []\n\n        for dataset_idx, (dataset, _groups_per_batch) in enumerate(\n            zip(self.datasets, self.groups_per_batch_list)\n        ):\n            env_group_builders = dataset.get_batch(i_batch)\n            all_env_group_builders.extend(env_group_builders)\n            all_dataset_indices.extend([dataset_idx] * len(env_group_builders))\n\n        return all_env_group_builders, all_dataset_indices\n\n\nclass PromptOnlyEnv(ProblemEnv):\n    \"\"\"Environment that only provides prompts with no rewards.\"\"\"\n\n    def __init__(\n        self,\n        prompt: str,\n        renderer: renderers.Renderer,\n        convo_prefix: list[renderers.Message] | None = None,\n    ):\n        # Set format_coef to 0 since we don't care about format\n        super().__init__(renderer, convo_prefix, format_coef=0.0)\n        self.prompt = prompt\n\n    def get_question(self) -> str:\n        return self.prompt\n\n    def check_format(self, sample_str: str) -> bool:\n        # Always return True - no format checking for distillation\n        return True\n\n    def check_answer(self, sample_str: str) -> bool:\n        # Always return False - no answer checking for distillation\n        return False\n\n    def get_reference_answer(self) -> str:\n        \"\"\"No reference answer 
needed for distillation.\"\"\"\n        return \"\"\n\n    async def step(self, action: Action) -> StepResult:\n        \"\"\"Return zero reward always.\"\"\"\n        message, parse_success = self.renderer.parse_response(action)\n        return StepResult(\n            reward=0.0,\n            episode_done=True,\n            next_observation=tinker.ModelInput.empty(),\n            next_stop_condition=self.stop_condition,\n            metrics={},\n        )\n\n\nclass PromptOnlyDataset(RLDataset):\n    \"\"\"Dataset that provides prompts without rewards.\"\"\"\n\n    def __init__(\n        self,\n        prompts: list[str],\n        batch_size: int,\n        group_size: int,\n        renderer: renderers.Renderer,\n        tokenizer,\n        max_prompt_tokens: int | None = None,\n        convo_prefix: list[renderers.Message] | None = None,\n        dataset_name: str = \"prompts\",\n    ):\n        self.prompts = prompts\n        self.batch_size = batch_size\n        self.group_size = group_size\n        self.renderer = renderer\n        self.tokenizer = tokenizer\n        self.max_prompt_tokens = max_prompt_tokens\n        self.convo_prefix = convo_prefix\n        self.dataset_name = dataset_name\n\n    def _truncate_prompt(self, prompt: str) -> str:\n        \"\"\"Truncate prompt to max_prompt_tokens if specified.\"\"\"\n        if self.max_prompt_tokens is None:\n            return prompt\n\n        tokens = self.tokenizer.encode(prompt)\n        if len(tokens) > self.max_prompt_tokens:\n            tokens = tokens[: self.max_prompt_tokens]\n            return self.tokenizer.decode(tokens)\n        return prompt\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        batch_start = index * self.batch_size\n        batch_end = min((index + 1) * self.batch_size, len(self.prompts))\n        assert batch_start < batch_end, \"Incorrect batch size\"\n        return [\n            ProblemGroupBuilder(\n                env_thunk=partial(\n                    PromptOnlyEnv,\n                    self._truncate_prompt(prompt),\n                    self.renderer,\n                    convo_prefix=self.convo_prefix,\n                ),\n                num_envs=self.group_size,\n                dataset_name=self.dataset_name,\n            )\n            for prompt in self.prompts[batch_start:batch_end]\n        ]\n\n    def __len__(self) -> int:\n        return math.ceil(len(self.prompts) / self.batch_size)\n\n\ndef load_deepmath_prompts(split: Literal[\"train\", \"test\"] = \"train\") -> list[str] | None:\n    \"\"\"Load DeepMath prompts from HuggingFace. 
Returns None if split doesn't exist.\"\"\"\n    try:\n        ds = load_dataset(\"zwhe99/DeepMath-103K\", split=split)\n        # DeepMath has 'question' field containing the math problem\n        prompts = [row[\"question\"] for row in ds]  # type: ignore\n        return prompts\n    except Exception as e:\n        logger.warning(f\"Could not load {split} split for DeepMath: {e}\")\n        return None\n\n\ndef load_tulu3_prompts() -> list[str] | None:\n    \"\"\"\n    Load Tulu3 prompts from HuggingFace.\n\n    Extracts the first user message from each conversation.\n    Returns None if dataset cannot be loaded.\n    \"\"\"\n    try:\n        ds = load_dataset(\"allenai/tulu-3-sft-mixture\", split=\"train\")\n        prompts = []\n\n        for row in ds:  # type: ignore\n            messages = row[\"messages\"]  # type: ignore\n            # Extract first user message\n            first_user_msg = None\n            for msg in messages:\n                if msg[\"role\"] == \"user\":  # type: ignore\n                    first_user_msg = msg[\"content\"]  # type: ignore\n                    break\n\n            if first_user_msg:\n                prompts.append(first_user_msg)\n\n        return prompts\n    except Exception as e:\n        logger.warning(f\"Could not load Tulu3 dataset: {e}\")\n        return None\n\n\n@chz.chz\nclass PromptOnlyDatasetBuilder(RLDatasetBuilder):\n    \"\"\"Builder for prompt-only datasets.\"\"\"\n\n    dataset_name: str  # e.g., \"deepmath\"\n    groups_per_batch: int\n    group_size: int\n    model_name_for_tokenizer: str\n    renderer_name: str\n    convo_prefix: list[renderers.Message] | None = None\n    max_prompt_tokens: int | None = 1024  # Maximum tokens per prompt (None = no truncation)\n\n    async def __call__(self) -> tuple[PromptOnlyDataset, PromptOnlyDataset | None]:\n        tokenizer = get_tokenizer(self.model_name_for_tokenizer)\n        renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)\n\n        # Load prompts based on dataset name\n        if self.dataset_name == \"deepmath\":\n            train_prompts = load_deepmath_prompts(\"train\")\n            test_prompts = None\n        elif self.dataset_name == \"tulu3\":\n            train_prompts = load_tulu3_prompts()\n            test_prompts = None  # Tulu3 only has train split\n        else:\n            raise ConfigurationError(f\"Unknown dataset: {self.dataset_name}\")\n\n        if train_prompts is None:\n            raise DataError(f\"Could not load train split for {self.dataset_name}\")\n\n        train_dataset = PromptOnlyDataset(\n            prompts=train_prompts,\n            batch_size=self.groups_per_batch,\n            group_size=self.group_size,\n            renderer=renderer,\n            tokenizer=tokenizer,\n            max_prompt_tokens=self.max_prompt_tokens,\n            convo_prefix=self.convo_prefix,\n            dataset_name=self.dataset_name,\n        )\n\n        test_dataset = (\n            PromptOnlyDataset(\n                prompts=test_prompts,\n                batch_size=self.groups_per_batch,\n                group_size=1,  # Use group_size=1 for test\n                renderer=renderer,\n                tokenizer=tokenizer,\n                max_prompt_tokens=self.max_prompt_tokens,\n                convo_prefix=self.convo_prefix,\n                dataset_name=f\"{self.dataset_name}_test\",\n            )\n            if test_prompts is not None\n            else None\n        )\n\n        return train_dataset, test_dataset\n"
  },
  {
    "path": "tinker_cookbook/distillation/train_on_policy.py",
    "content": "\"\"\"\nImplements on-policy distillation. For more details, see:\nhttps://thinkingmachines.ai/blog/on-policy-distillation\n\"\"\"\n\nimport asyncio\nimport logging\nfrom collections.abc import Sequence\nfrom pathlib import Path\nfrom typing import Any, cast\n\nimport chz\nimport tinker\nimport torch\nfrom tinker.types import LossFnType\n\nfrom tinker_cookbook import checkpoint_utils, model_info\nfrom tinker_cookbook.display import colorize_example\nfrom tinker_cookbook.distillation.datasets import (\n    CompositeDataset,\n    DistillationDatasetConfig,\n)\nfrom tinker_cookbook.eval.evaluators import SamplingClientEvaluator, SamplingClientEvaluatorBuilder\nfrom tinker_cookbook.exceptions import ConfigurationError\nfrom tinker_cookbook.rl.data_processing import (\n    assemble_training_data,\n    compute_advantages,\n)\nfrom tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics\nfrom tinker_cookbook.rl.metrics import discounted_future_sum_vectorized\nfrom tinker_cookbook.rl.train import (\n    compute_full_batch_metrics_and_get_sampling_client,\n    do_group_rollout_and_filter_constant_reward,\n    save_checkpoint_and_get_sampling_client,\n    train_step,\n)\nfrom tinker_cookbook.rl.types import (\n    EnvGroupBuilder,\n    TrajectoryGroup,\n)\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\nfrom tinker_cookbook.utils import ml_log, trace\nfrom tinker_cookbook.utils.deprecation import warn_deprecated\nfrom tinker_cookbook.utils.misc_utils import safezip\n\nlogger = logging.getLogger(__name__)\n\n\n@trace.scope\nasync def incorporate_kl_penalty(\n    data_D: list[tinker.Datum],\n    teacher_clients_D: list[tinker.SamplingClient],\n    dataset_indices_D: list[int],\n    kl_penalty_coef: float,\n    kl_discount_factor: float,\n) -> dict[str, float]:\n    \"\"\"\n    Compute reverse KL between the student (log p) and the teacher model (log q), computed as\n    log p - log q. 
We then adjust the advantages in-place as the negative reverse KL.\n\n    Args:\n        data_D: List of datums to compute KL for\n        teacher_clients_D: List of teacher sampling clients, one per datum\n        dataset_indices_D: List of dataset indices, one per datum\n        kl_penalty_coef: Coefficient for KL penalty\n        kl_discount_factor: Discount factor for future KL\n    \"\"\"\n    # Note: if your teacher has a different renderer than the student, you may want to modify\n    #       the full_sequence_inputs_D to match the teacher's renderer.\n    full_sequence_inputs_D = [\n        datum.model_input.append_int(cast(int, datum.loss_fn_inputs[\"target_tokens\"].data[-1]))\n        for datum in data_D\n    ]\n    # Compute the teacher's logprobs for each element of the batch\n    # Each datum uses its corresponding teacher sampling client\n    teacher_logprobs_D = await asyncio.gather(\n        *[\n            teacher_client.compute_logprobs_async(sequence_input)\n            for teacher_client, sequence_input in zip(teacher_clients_D, full_sequence_inputs_D)\n        ]\n    )\n    # The reverse KL is computed as KL[p||q] = log p - log q, where\n    #   - p: sampled_logprobs\n    #   - q: teacher_logprobs\n    sampled_logprobs_D = [datum.loss_fn_inputs[\"logprobs\"].to_torch() for datum in data_D]\n    float_masks = [datum.loss_fn_inputs[\"mask\"].to_torch().float() for datum in data_D]\n    reverse_kl = [\n        (sampled_logprobs - torch.tensor(teacher_logprobs[1:])) * mask\n        for teacher_logprobs, sampled_logprobs, mask in safezip(\n            teacher_logprobs_D, sampled_logprobs_D, float_masks\n        )\n    ]\n    # Track per-dataset KL for logging\n    # dataset_idx -> (sum of KL, sum of mask)\n    per_dataset_kl: dict[int, tuple[float, float]] = {}\n\n    for i, datum in enumerate(data_D):\n        # The advantage is the negative reverse KL. 
We can optionally apply a discount factor.\n        kl_advantages = -kl_penalty_coef * float_masks[i] * reverse_kl[i]\n        if kl_discount_factor > 0:\n            kl_advantages = discounted_future_sum_vectorized(kl_advantages, kl_discount_factor)\n        datum.loss_fn_inputs[\"advantages\"] = tinker.TensorData.from_torch(\n            datum.loss_fn_inputs[\"advantages\"].to_torch() + kl_advantages\n        )\n\n        # Accumulate per-dataset KL\n        dataset_idx = dataset_indices_D[i]\n        kl_sum = reverse_kl[i].sum().item()\n        mask_sum = float_masks[i].sum().item()\n        if dataset_idx not in per_dataset_kl:\n            per_dataset_kl[dataset_idx] = (0.0, 0.0)\n        prev_kl_sum, prev_mask_sum = per_dataset_kl[dataset_idx]\n        per_dataset_kl[dataset_idx] = (prev_kl_sum + kl_sum, prev_mask_sum + mask_sum)\n\n    # Compute average reverse KL over the batch for logging purposes\n    avg_logp_diff = sum([diff.sum() for diff in reverse_kl]) / sum(\n        [mask.sum() for mask in float_masks]\n    )\n\n    # Compute per-dataset metrics\n    metrics = {\"teacher_kl\": float(avg_logp_diff)}\n    for dataset_idx, (kl_sum, mask_sum) in per_dataset_kl.items():\n        if mask_sum > 0:\n            metrics[f\"teacher_kl/dataset_{dataset_idx}\"] = float(kl_sum / mask_sum)\n\n    return metrics\n\n\n@chz.chz\nclass Config:\n    learning_rate: float\n    dataset_configs: list[DistillationDatasetConfig]\n    model_name: str\n    renderer_name: str | None = None\n    max_tokens: int\n    temperature: float = 1.0\n    compute_post_kl: bool = False\n    evaluator_builders: list[SamplingClientEvaluatorBuilder] = chz.field(default_factory=list)\n    lora_rank: int = 32\n\n    kl_penalty_coef: float = 1.0\n    kl_discount_factor: float = 0.0\n\n    # Loss function and configuration.\n    # See https://tinker-docs.thinkingmachines.ai/losses\n    loss_fn: LossFnType = \"importance_sampling\"\n    loss_fn_config: dict[str, Any] | None = None\n\n    # Number of optimizer steps per training iteration.\n    # Useful for very large batch sizes.\n    num_substeps: int = 1\n\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    log_path: str = chz.field(munger=lambda _, s: str(Path(s).expanduser()))\n    base_url: str | None = None\n    enable_trace: bool = False\n    span_chart_every: int = 0\n\n    eval_every: int = 20\n    save_every: int = 20\n    load_checkpoint_path: str | None = None\n\n    # Maximum number of training steps. If None, train on the full dataset.\n    max_steps: int | None = None\n    # Deprecated alias for max_steps. 
Use max_steps instead.\n    max_step: int | None = None\n\n\n@trace.scope\nasync def prepare_minibatch(\n    env_group_builders_P: Sequence[EnvGroupBuilder],\n    trajectory_groups_P: list[TrajectoryGroup],\n    tokenizer: Tokenizer,\n    dataset_indices_P: list[int],\n    teacher_clients: list[tinker.SamplingClient],\n    kl_penalty_coef: float,\n    kl_discount_factor: float,\n) -> tuple[list[tinker.Datum], dict[str, Any]]:\n    \"\"\"Converts the trajectories into a minibatch, and provides metrics about the minibatch\"\"\"\n\n    # Compute trajectory metrics\n    metrics = {}\n    taglist_P = [env_group_builder.logging_tags() for env_group_builder in env_group_builders_P]\n    metrics.update(compute_trajectory_metrics(trajectory_groups_P, taglist_P))\n\n    # Assemble training data\n    async with trace.scope_span(\"assemble_training_data\"):\n        advantages_P = compute_advantages(trajectory_groups_P)\n        data_D, metadata_D = assemble_training_data(trajectory_groups_P, advantages_P)\n\n    # Print one datum per dataset\n    printed_datasets = set()\n    for datum, metadata in zip(data_D, metadata_D):\n        dataset_idx = dataset_indices_P[metadata[\"group_idx\"]]\n        if dataset_idx not in printed_datasets:\n            logger.info(colorize_example(datum, tokenizer, key=\"mask\"))\n            printed_datasets.add(dataset_idx)\n\n    # Incorporate KL penalty if configured\n    if kl_penalty_coef > 0:\n        async with trace.scope_span(\"compute_kl_penalty\"):\n            # Map each datum to its teacher sampling client and dataset index using metadata\n            #   - metadata_D contains group_idx which indexes into trajectory_groups_P\n            #   - dataset_indices_P[group_idx] gives us the dataset index\n            #   - teacher_clients[dataset_idx] gives us the teacher\n            teacher_clients_D = [\n                teacher_clients[dataset_indices_P[metadata[\"group_idx\"]]] for metadata in metadata_D\n            ]\n            dataset_indices_D = [\n                dataset_indices_P[metadata[\"group_idx\"]] for metadata in metadata_D\n            ]\n            kl_penalty_metrics = await incorporate_kl_penalty(\n                data_D,\n                teacher_clients_D,\n                dataset_indices_D,\n                kl_penalty_coef,\n                kl_discount_factor,\n            )\n        metrics.update(kl_penalty_metrics)\n\n    return data_D, metrics\n\n\n@trace.scope\nasync def do_train_step_and_get_sampling_client(\n    cfg: Config,\n    i_batch: int,\n    training_client: tinker.TrainingClient,\n    service_client: tinker.ServiceClient,\n    tokenizer: Tokenizer,\n    env_group_builders_P: Sequence[EnvGroupBuilder],\n    trajectory_groups_P: list[TrajectoryGroup],\n    dataset_indices_P: list[int],\n    teacher_clients: list[tinker.SamplingClient],\n) -> tuple[tinker.SamplingClient, dict[str, Any]]:\n    trace.update_scope_context({\"step\": i_batch})\n\n    metrics = {}\n    data_D, prepare_minibatch_metrics = await prepare_minibatch(\n        env_group_builders_P,\n        trajectory_groups_P,\n        tokenizer,\n        dataset_indices_P,\n        teacher_clients,\n        kl_penalty_coef=cfg.kl_penalty_coef,\n        kl_discount_factor=cfg.kl_discount_factor,\n    )\n    metrics.update(prepare_minibatch_metrics)\n\n    async with trace.scope_span(\"train\"):\n        training_logprobs_D = await train_step(\n            data_D=data_D,\n            training_client=training_client,\n            learning_rate=cfg.learning_rate,\n         
   num_substeps=cfg.num_substeps,\n            loss_fn=cfg.loss_fn,\n            loss_fn_config=cfg.loss_fn_config,\n            metrics=metrics,\n        )\n\n    sampling_client, full_batch_metrics = await compute_full_batch_metrics_and_get_sampling_client(\n        training_client,\n        # NOTE: saving the checkpoint as the i + 1 step\n        i_batch + 1,\n        data_D,\n        training_logprobs_D,\n        cfg.log_path,\n        cfg.save_every,\n        cfg.compute_post_kl,\n    )\n    metrics.update(full_batch_metrics)\n\n    return sampling_client, metrics\n\n\n@trace.scope\nasync def do_sync_training(\n    start_batch: int,\n    end_batch: int,\n    num_batches: int,\n    cfg: Config,\n    training_client: tinker.TrainingClient,\n    service_client: tinker.ServiceClient,\n    evaluators: list[SamplingClientEvaluator],\n    dataset: CompositeDataset,\n    teacher_clients: list[tinker.SamplingClient],\n    ml_logger: ml_log.Logger,\n    tokenizer: Tokenizer,\n):\n    \"\"\"Implements fully synchronous on-policy training\"\"\"\n\n    # Initial sampling client\n    sampling_client, _ = await save_checkpoint_and_get_sampling_client(\n        training_client, start_batch, cfg.log_path, cfg.save_every\n    )\n\n    log_path = Path(cfg.log_path)\n\n    for i_batch in range(start_batch, end_batch):\n        metrics = {\n            \"progress/batch\": i_batch,\n            \"optim/lr\": cfg.learning_rate,\n            \"progress/done_frac\": (i_batch + 1) / num_batches,\n        }\n\n        with trace.trace_iteration(step=i_batch) as window:\n            # Run evaluations\n            if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0:\n                async with trace.scope_span(\"run_evals\"):\n                    for evaluator in evaluators:\n                        eval_metrics = await evaluator(sampling_client)\n                        metrics.update({f\"test/{k}\": v for k, v in eval_metrics.items()})\n\n            # Get batch and sample trajectories\n            env_group_builders_P, dataset_indices_P = dataset.get_batch(i_batch)\n            async with trace.scope_span(\"sample\"):\n                trajectory_groups_P = await asyncio.gather(\n                    *[\n                        asyncio.create_task(\n                            do_group_rollout_and_filter_constant_reward(\n                                sampling_client,\n                                builder,\n                                temperature=cfg.temperature,\n                                max_tokens=cfg.max_tokens,\n                                do_remove_constant_reward_groups=False,\n                            ),\n                            name=f\"sample_task_{i}\",\n                        )\n                        for i, builder in enumerate(env_group_builders_P)\n                    ],\n                )\n            trajectory_groups_P = [\n                trajectory_group\n                for trajectory_group in trajectory_groups_P\n                if trajectory_group is not None\n            ]\n\n            # Train step\n            sampling_client, train_step_metrics = await do_train_step_and_get_sampling_client(\n                cfg,\n                i_batch,\n                training_client,\n                service_client,\n                tokenizer,\n                env_group_builders_P,\n                trajectory_groups_P,\n                dataset_indices_P,\n                teacher_clients,\n            )\n\n            metrics.update(train_step_metrics)\n\n        # 
Log timing metrics from trace_iteration window\n        metrics.update(window.get_timing_metrics())\n        window.write_spans_jsonl(log_path / \"timing_spans.jsonl\", step=i_batch)\n        if cfg.span_chart_every > 0 and i_batch % cfg.span_chart_every == 0:\n            trace.save_gantt_chart_html(\n                window, i_batch, log_path / f\"timing_gantt_{i_batch:06d}.html\"\n            )\n        ml_logger.log_metrics(metrics, step=i_batch)\n\n\n@trace.scope\nasync def main(\n    cfg: Config,\n):\n    \"\"\"Main training loop for on-policy distillation.\"\"\"\n    ml_logger = ml_log.setup_logging(\n        log_dir=cfg.log_path,\n        wandb_project=cfg.wandb_project,\n        config=cfg,\n        wandb_name=cfg.wandb_name,\n    )\n    if cfg.enable_trace:\n        # Get and rename the current (main) task\n        current_task = asyncio.current_task()\n        if current_task is not None:\n            current_task.set_name(\"main\")\n        trace_events_path = str(Path(cfg.log_path) / \"trace_events.jsonl\")\n        logger.info(f\"Tracing is enabled. Trace events will be saved to {trace_events_path}\")\n        logger.info(\n            f\"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/\"\n        )\n        trace.trace_init(output_file=trace_events_path)\n\n    logging.getLogger(\"httpx\").setLevel(logging.WARNING)\n    logging.getLogger(\"pylatexenc\").setLevel(logging.WARNING)\n\n    resume_info = checkpoint_utils.get_last_checkpoint(cfg.log_path)\n    if resume_info:\n        start_batch = resume_info.batch\n    else:\n        start_batch = 0\n\n    service_client = tinker.ServiceClient(base_url=cfg.base_url)\n    user_metadata: dict[str, str] = {}\n    if wandb_link := ml_logger.get_logger_url():\n        user_metadata[\"wandb_link\"] = wandb_link\n    checkpoint_utils.add_renderer_name_to_user_metadata(user_metadata, cfg.renderer_name)\n    model_info.warn_if_renderer_not_recommended(cfg.model_name, cfg.renderer_name)\n\n    if resume_info:\n        # Resuming interrupted training - load optimizer state for proper continuation\n        await checkpoint_utils.check_renderer_name_for_checkpoint_async(\n            service_client, resume_info.state_path, cfg.renderer_name\n        )\n        training_client = (\n            await service_client.create_training_client_from_state_with_optimizer_async(\n                resume_info.state_path, user_metadata=user_metadata\n            )\n        )\n        logger.info(f\"Resumed training from {resume_info.state_path}\")\n    elif cfg.load_checkpoint_path:\n        # Starting fresh from a checkpoint - load weights only (fresh optimizer)\n        await checkpoint_utils.check_renderer_name_for_checkpoint_async(\n            service_client, cfg.load_checkpoint_path, cfg.renderer_name\n        )\n        training_client = await service_client.create_training_client_from_state_async(\n            cfg.load_checkpoint_path, user_metadata=user_metadata\n        )\n        logger.info(f\"Loaded weights from {cfg.load_checkpoint_path}\")\n    else:\n        training_client = await service_client.create_lora_training_client_async(\n            cfg.model_name, rank=cfg.lora_rank, user_metadata=user_metadata\n        )\n\n    # Get tokenizer from training client\n    tokenizer = training_client.get_tokenizer()\n\n    # Create datasets and teacher sampling clients from configs\n    datasets = []\n    teacher_clients = []\n    groups_per_batch_list = []\n    
evaluators = [evaluator() for evaluator in cfg.evaluator_builders]\n\n    for dataset_config in cfg.dataset_configs:\n        # Create dataset\n        dataset, maybe_test_dataset = await dataset_config.dataset_builder()\n        datasets.append(dataset)\n        groups_per_batch_list.append(dataset_config.groups_per_batch)\n\n        # Add test dataset evaluator if present\n        if maybe_test_dataset is not None:\n            evaluators.append(RLTestSetEvaluator(maybe_test_dataset, max_tokens=cfg.max_tokens))\n\n        # Create teacher sampling client\n        teacher_config = dataset_config.teacher_config\n        teacher_client = service_client.create_sampling_client(base_model=teacher_config.base_model)\n        # Load teacher checkpoint if specified\n        if teacher_config.load_checkpoint_path is not None:\n            teacher_client = service_client.create_sampling_client(\n                base_model=teacher_config.base_model,\n                model_path=teacher_config.load_checkpoint_path,\n            )\n        teacher_clients.append(teacher_client)\n        logger.info(\n            f\"Created teacher sampling client for {teacher_config.base_model} \"\n            f\"(checkpoint: {teacher_config.load_checkpoint_path})\"\n        )\n\n    # Wrap datasets in CompositeDataset\n    composite_dataset = CompositeDataset(datasets, groups_per_batch_list)\n    num_batches = len(composite_dataset)\n    # Resolve max_steps from either max_steps or deprecated max_step\n    effective_max_steps = cfg.max_steps\n    if cfg.max_step is not None:\n        if cfg.max_steps is not None:\n            raise ConfigurationError(\"Cannot specify both max_steps and max_step. Use max_steps.\")\n        warn_deprecated(\"max_step\", removal_version=\"0.3.0\", message=\"Use 'max_steps' instead.\")\n        effective_max_steps = cfg.max_step\n    num_batches = (\n        min(effective_max_steps, num_batches) if effective_max_steps is not None else num_batches\n    )\n    logger.info(f\"Will train on {num_batches} batches (dataset has {num_batches})\")\n\n    # Training loop\n    await do_sync_training(\n        start_batch=start_batch,\n        end_batch=num_batches,\n        num_batches=num_batches,\n        cfg=cfg,\n        training_client=training_client,\n        service_client=service_client,\n        evaluators=evaluators,\n        dataset=composite_dataset,\n        teacher_clients=teacher_clients,\n        ml_logger=ml_logger,\n        tokenizer=tokenizer,\n    )\n\n    # Save final checkpoint\n    if start_batch < num_batches:\n        _ = await checkpoint_utils.save_checkpoint_async(\n            training_client=training_client,\n            name=\"final\",\n            log_path=cfg.log_path,\n            kind=\"both\",\n            loop_state={\"batch\": num_batches},\n            ttl_seconds=None,\n        )\n    else:\n        logger.info(\"Training was already complete; nothing to do\")\n\n    # Cleanup\n    ml_logger.close()\n    logger.info(\"Training completed successfully\")\n"
  },
  {
    "path": "tinker_cookbook/eval/README.md",
    "content": "Tinker can integrate with oss eval framework like Inspect AI easily (`run_inspect_evals.py`), or you can create a simple evaluator yourself ([`custom_evaluators.py`](./custom_evaluators.py)). Check out our [docs](https://tinker-docs.thinkingmachines.ai/evals).\n"
  },
  {
    "path": "tinker_cookbook/eval/__init__.py",
    "content": ""
  },
  {
    "path": "tinker_cookbook/eval/custom_evaluators.py",
    "content": "import asyncio\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport tinker\nfrom tinker import types\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.eval.evaluators import SamplingClientEvaluator\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\nclass CustomEvaluator(SamplingClientEvaluator):\n    \"\"\"\n    A toy SamplingClientEvaluator that runs a custom evaluation and returns its metrics.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset: Any,\n        grader_fn: Callable[[str, str], bool],\n        model_name: str,\n        renderer_name: str,\n    ):\n        \"\"\"\n        Initialize the CustomEvaluator.\n        Args:\n            config: Configuration object containing all evaluation parameters\n        \"\"\"\n        self.dataset = dataset\n        self.grader_fn = grader_fn\n\n        tokenizer = get_tokenizer(model_name)\n        self.renderer = renderers.get_renderer(name=renderer_name, tokenizer=tokenizer)\n\n    async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:\n        \"\"\"\n        Run custom evaluation on the given sampling client and return metrics.\n        Args:\n            sampling_client: The sampling client to evaluate\n        Returns:\n            Dictionary of metrics from inspect evaluation\n        \"\"\"\n\n        metrics = {}\n\n        num_examples = len(self.dataset)\n        num_correct = 0\n\n        sampling_params = types.SamplingParams(\n            max_tokens=100,\n            temperature=0.7,\n            top_p=1.0,\n            stop=self.renderer.get_stop_sequences(),\n        )\n\n        for datum in self.dataset:\n            model_input: types.ModelInput = self.renderer.build_generation_prompt(\n                [renderers.Message(role=\"user\", content=datum[\"input\"])]\n            )\n            # Generate response\n            r: types.SampleResponse = await sampling_client.sample_async(\n                prompt=model_input, num_samples=1, sampling_params=sampling_params\n            )\n            tokens: list[int] = r.sequences[0].tokens\n            response: renderers.Message = self.renderer.parse_response(tokens)[0]\n            content = renderers.get_text_content(response)\n            if self.grader_fn(content, datum[\"output\"]):\n                num_correct += 1\n\n        metrics[\"accuracy\"] = num_correct / num_examples\n        return metrics\n\n\nQA_DATASET = [\n    {\"input\": \"What is the capital of France?\", \"output\": \"Paris\"},\n    {\"input\": \"What is the capital of Germany?\", \"output\": \"Berlin\"},\n    {\"input\": \"What is the capital of Italy?\", \"output\": \"Rome\"},\n]\n\n\ndef grader_fn(response: str, target: str) -> bool:\n    return target.lower() in response.lower()\n\n\nevaluator = CustomEvaluator(\n    dataset=QA_DATASET,\n    grader_fn=grader_fn,\n    renderer_name=\"llama3\",\n    model_name=\"meta-llama/Llama-3.1-8B-Instruct\",\n)\n\nservice_client = tinker.ServiceClient()\nsampling_client = service_client.create_sampling_client(\n    base_model=\"meta-llama/Llama-3.1-8B-Instruct\"\n)\n\n\nasync def main():\n    result = await evaluator(sampling_client)\n    print(result)\n\n\nasyncio.run(main())\n"
  },
  {
    "path": "tinker_cookbook/eval/custom_inspect_task.py",
    "content": "\"\"\"\nExample of using LLM-as-a-judge with inspect_ai.\n\nTo run this task, use:\npython -m tinker_cookbook.eval.run_inspect_evals \\\n    model_path=tinker://your-model-path \\\n    tasks=tinker_cookbook.eval.custom_inspect_task:example_lm_as_judge \\\n    renderer_name=role_colon \\\n    model_name=Qwen/Qwen3-8B-Base\n\"\"\"\n\nimport tinker\nfrom inspect_ai import Task, task\nfrom inspect_ai.dataset import MemoryDataset, Sample\nfrom inspect_ai.model import GenerateConfig as InspectAIGenerateConfig\nfrom inspect_ai.model import Model as InspectAIModel\nfrom inspect_ai.scorer import model_graded_qa\nfrom inspect_ai.solver import generate\n\nfrom tinker_cookbook.eval.inspect_utils import InspectAPIFromTinkerSampling\n\nQA_DATASET = MemoryDataset(\n    name=\"qa_dataset\",\n    samples=[\n        Sample(\n            input=\"What is the capital of France?\",\n            target=\"Paris\",\n        ),\n        Sample(\n            input=\"What is the capital of Italy?\",\n            target=\"Rome\",\n        ),\n    ],\n)\n\nservice_client = tinker.ServiceClient()\nsampling_client = service_client.create_sampling_client(\n    base_model=\"meta-llama/Llama-3.1-8B-Instruct\"\n)\n\napi = InspectAPIFromTinkerSampling(\n    renderer_name=\"llama3\",  # pyright: ignore[reportCallIssue]\n    model_name=\"meta-llama/Llama-3.1-8B-Instruct\",\n    sampling_client=sampling_client,  # pyright: ignore[reportCallIssue]\n    verbose=False,  # pyright: ignore[reportCallIssue]\n)\n\nGRADER_MODEL = InspectAIModel(api=api, config=InspectAIGenerateConfig())\n\n\n@task\ndef example_lm_as_judge() -> Task:\n    \"\"\"\n    Example task using LLM-as-a-judge scoring.\n\n    Note: The grader model defaults to the model being evaluated.\n    To use a different grader model, specify it with --model-grader when using inspect directly.\n    \"\"\"\n    return Task(\n        name=\"llm_as_judge\",\n        dataset=QA_DATASET,\n        solver=generate(),\n        scorer=model_graded_qa(\n            instructions=\"Grade strictly against the target text as general answer key and rubric. \"\n            \"Respond 'GRADE: C' if correct or 'GRADE: I' otherwise.\",\n            partial_credit=False,\n            # model parameter is optional - if not specified, uses the model being evaluated\n            model=GRADER_MODEL,\n        ),\n    )\n"
  },
  {
    "path": "tinker_cookbook/eval/evaluators.py",
    "content": "import logging\nfrom collections.abc import Callable\n\nimport tinker\n\n# Set up logger\nlogger = logging.getLogger(__name__)\n\n\nclass TrainingClientEvaluator:\n    \"\"\"\n    An evaluator that takes in a TrainingClient\n    \"\"\"\n\n    async def __call__(self, training_client: tinker.TrainingClient) -> dict[str, float]:\n        raise NotImplementedError\n\n\nclass SamplingClientEvaluator:\n    \"\"\"\n    An evaluator that takes in a TokenCompleter\n    \"\"\"\n\n    async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:\n        raise NotImplementedError\n\n\nEvaluatorBuilder = Callable[[], TrainingClientEvaluator | SamplingClientEvaluator]\nSamplingClientEvaluatorBuilder = Callable[[], SamplingClientEvaluator]\nEvaluator = TrainingClientEvaluator | SamplingClientEvaluator\n"
  },
  {
    "path": "tinker_cookbook/eval/inspect_evaluators.py",
    "content": "import logging\nfrom pathlib import Path\n\nimport chz\nimport tinker\nfrom inspect_ai import Tasks, eval_async\nfrom inspect_ai.model import GenerateConfig as InspectAIGenerateConfig\nfrom inspect_ai.model import Model as InspectAIModel\n\nfrom tinker_cookbook.eval.evaluators import SamplingClientEvaluator\nfrom tinker_cookbook.eval.inspect_utils import InspectAPIFromTinkerSampling\nfrom tinker_cookbook.exceptions import ConfigurationError\n\n# Set up logger\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass InspectEvaluatorBuilder:\n    \"\"\"\n    Configuration for inspect evaluation.\n    This class provides a structured way to configure inspect evaluation\n    parameters that can be used both in training configs and evaluator builders.\n    \"\"\"\n\n    # Required parameters\n    tasks: Tasks\n    renderer_name: str | None = None\n    # TODO: remove model_name once the SDK adds a get_tokenizer method to sampling client\n    model_name: str | None = None\n    # Random seed for sampling. If None, sampling is non-deterministic.\n    seed: int | None = None\n    # If True, logs prompts and responses to the console (useful for debugging).\n    verbose: bool = False\n    # When True, model reasoning/thinking is preserved in inspect output as ContentReasoning.\n    # When False (default), reasoning is stripped and only text content is returned.\n    include_reasoning: bool = False\n\n    # Generation parameters\n    temperature: float = 1.0\n    max_tokens: int = 1000\n    top_p: float = 1.0\n    # Top-k sampling. -1 disables top-k filtering (uses all tokens).\n    top_k: int = -1\n    # Number of independent responses to generate per prompt. Used for majority\n    # voting or best-of-n evaluation strategies.\n    num_choices: int = 1\n\n    # Evaluation parameters\n    # Maximum number of samples to evaluate. 
If None, evaluates all samples.\n    limit: int | None = None\n    debug_errors: bool = True\n    log_dir: str | None = None\n    # Maximum concurrent sampling requests to Tinker.\n    max_connections: int = 512\n    log_level: str = \"INFO\"\n    # Metadata to associate with this evaluation run (visible in inspect logs)\n    metadata: dict[str, str] | None = None\n\n    def __call__(self) -> SamplingClientEvaluator:\n        return InspectEvaluator(self)\n\n\nclass InspectEvaluator(SamplingClientEvaluator):\n    \"\"\"\n    A SamplingClientEvaluator that runs inspect tasks and returns their metrics.\n    \"\"\"\n\n    def __init__(self, config: InspectEvaluatorBuilder):\n        \"\"\"\n        Initialize the InspectEvaluator.\n        Args:\n            config: Configuration object containing all evaluation parameters\n        \"\"\"\n        self.config = config\n\n    async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:\n        \"\"\"\n        Run inspect evaluation on the given sampling client and return metrics.\n        Args:\n            sampling_client: The sampling client to evaluate\n        Returns:\n            Dictionary of metrics from inspect evaluation\n        \"\"\"\n        if self.config.model_name is None:\n            raise ConfigurationError(\"model_name must be set before running evaluation\")\n        if self.config.renderer_name is None:\n            raise ConfigurationError(\"renderer_name must be set before running evaluation\")\n        # Create the inspect API wrapper\n        api = InspectAPIFromTinkerSampling(\n            renderer_name=self.config.renderer_name,\n            model_name=self.config.model_name,\n            sampling_client=sampling_client,\n            verbose=self.config.verbose,\n            include_reasoning=self.config.include_reasoning,\n        )\n        # Create the inspect model\n        model = InspectAIModel(\n            api=api,\n            config=InspectAIGenerateConfig(\n                temperature=self.config.temperature,\n                max_tokens=self.config.max_tokens,\n                top_p=self.config.top_p,\n                top_k=self.config.top_k,\n                seed=self.config.seed,\n                num_choices=self.config.num_choices,\n            ),\n        )\n\n        # Run evaluation\n        results = await eval_async(\n            tasks=self.config.tasks,\n            model=[model],\n            limit=self.config.limit,\n            debug_errors=self.config.debug_errors,\n            # Never retry - the tinker SDK is doing this for us already\n            retry_on_error=0,\n            # Although Tinker sampling tries very hard to only throw unrecoverable failures,\n            # the inspect evaluation can still fail if e.g. 
the parser returns an error for\n            # a given sample.\n            fail_on_error=False,\n            log_dir=self.config.log_dir or str(Path(\"~/inspect-logs\").expanduser()),\n            max_connections=self.config.max_connections,\n            log_level=self.config.log_level,\n            log_realtime=False,\n            log_buffer=1000,\n            metadata=self.config.metadata,\n        )\n\n        # Extract metrics from results\n        metrics = {}\n        for task_result in results:\n            if task_result.results is not None and task_result.results.scores is not None:\n                for task_name, score in task_result.results.scores[0].metrics.items():\n                    if task_result.eval.dataset is not None:\n                        dataset_name = task_result.eval.dataset.name\n                    else:\n                        dataset_name = \"unknown\"\n                    metrics[dataset_name + \"/\" + task_name] = score.value  # pyright: ignore[reportOptionalOperand]\n\n        logger.info(f\"Inspect evaluation completed. Metrics: {metrics}\")\n        return metrics\n"
  },
  {
    "path": "tinker_cookbook/eval/inspect_utils.py",
    "content": "\"\"\"\nShared utilities for inspect evaluation.\n\nThis module contains the common classes and functions used by both\nrun_inspect_evals.py and inspect_evaluator.py to avoid code duplication.\n\"\"\"\n\nimport logging\nimport time\nfrom collections.abc import Sequence\n\nimport tinker\nfrom inspect_ai.model import ChatCompletionChoice as InspectAIModelOutputChoice\nfrom inspect_ai.model import ChatMessage as InspectAIChatMessage\nfrom inspect_ai.model import ChatMessageAssistant as InspectAIChatMessageAssistant\nfrom inspect_ai.model import ChatMessageSystem, Content\nfrom inspect_ai.model import ContentReasoning as InspectAIContentReasoning\nfrom inspect_ai.model import ContentText as InspectAIContentText\nfrom inspect_ai.model import GenerateConfig as InspectAIGenerateConfig\nfrom inspect_ai.model import ModelAPI as InspectAIModelAPI\nfrom inspect_ai.model import ModelOutput as InspectAIModelOutput\nfrom inspect_ai.model import ModelUsage as InspectAIModelUsage\nfrom inspect_ai.model._registry import modelapi_register\nfrom inspect_ai.tool import ToolChoice as InspectAIToolChoice\nfrom inspect_ai.tool import ToolInfo as InspectAIToolInfo\nfrom termcolor import colored\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.exceptions import ConfigurationError\nfrom tinker_cookbook.renderers.base import ensure_list\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_model_usage(\n    tokenized_prompt: Sequence[int], responses: Sequence[tinker.SampledSequence]\n) -> InspectAIModelUsage:\n    \"\"\"\n    Given a tokenized prompt and a list of responses, return the number of tokens used/generated by the model.\n    \"\"\"\n    num_input_tokens = len(tokenized_prompt)\n    num_output_tokens = sum(len(r.tokens) for r in responses)\n    total_tokens = num_input_tokens + num_output_tokens\n    usage = InspectAIModelUsage(\n        input_tokens=num_input_tokens, output_tokens=num_output_tokens, total_tokens=total_tokens\n    )\n    return usage\n\n\ndef convert_inspect_messages(messages: list[InspectAIChatMessage]) -> list[renderers.Message]:\n    result: list[renderers.Message] = []\n    for m in messages:\n        content = m.content\n        if isinstance(content, str):\n            result.append(renderers.Message(role=m.role, content=content.strip()))\n        else:\n            # Structured content list from inspect_ai\n            parts: list[renderers.ContentPart] = []\n            for item in content:\n                if isinstance(item, InspectAIContentText):\n                    parts.append(renderers.TextPart(type=\"text\", text=item.text))\n                elif isinstance(item, InspectAIContentReasoning):\n                    parts.append(renderers.ThinkingPart(type=\"thinking\", thinking=item.reasoning))\n                else:\n                    logger.warning(\n                        f\"Skipping unsupported inspect content type: {type(item).__name__}\"\n                    )\n            # For non-assistant roles, flatten to string (reasoning in user/system is meaningless)\n            if m.role != \"assistant\" or not parts:\n                text = \" \".join(\n                    p[\"text\"] if p[\"type\"] == \"text\" else p[\"thinking\"]  # type: ignore[typeddict-item]\n                    for p in parts\n                ).strip()\n                result.append(renderers.Message(role=m.role, content=text))\n            else:\n                result.append(renderers.Message(role=m.role, 
content=parts))\n    return result\n\n\ndef _message_to_inspect_content(\n    message: renderers.Message,\n) -> list[Content]:\n    \"\"\"Convert a renderer Message's content parts to inspect_ai content types.\"\"\"\n    parts = ensure_list(message[\"content\"])\n    result: list[Content] = []\n    for part in parts:\n        if part[\"type\"] == \"thinking\":\n            result.append(InspectAIContentReasoning(reasoning=part[\"thinking\"]))\n        elif part[\"type\"] == \"text\":\n            result.append(InspectAIContentText(text=part[\"text\"]))\n        else:\n            logger.warning(f\"Skipping unsupported content part type in response: {part['type']}\")\n    return result\n\n\nclass InspectAPIFromTinkerSampling(InspectAIModelAPI):\n    \"\"\"\n    A model API wrapper that adapts tinker sampling clients to the inspect API interface.\n\n    This class can be initialized either with a model_path (for standalone use)\n    or with a sampling_client (for use in evaluators).\n    \"\"\"\n\n    def __init__(\n        self,\n        renderer_name: str,\n        model_name: str,\n        model_path: str | None = None,\n        sampling_client: tinker.SamplingClient | None = None,\n        base_url: str | None = None,\n        api_key: str | None = None,\n        api_key_vars: list[str] | None = None,\n        config: InspectAIGenerateConfig = InspectAIGenerateConfig(),\n        verbose: bool = False,\n        include_reasoning: bool = False,\n    ):\n        if api_key_vars is None:\n            api_key_vars = []\n        super().__init__(\n            model_name=model_name,\n            base_url=base_url,\n            api_key=api_key,\n            api_key_vars=api_key_vars,\n            config=config,\n        )\n\n        # Initialize sampling client\n        if sampling_client is not None:\n            self.sampling_client = sampling_client\n        elif model_path is not None:\n            service_client = tinker.ServiceClient(api_key=api_key)\n            self.sampling_client = service_client.create_sampling_client(model_path=model_path)\n        else:\n            raise ConfigurationError(\"Either model_path or sampling_client must be provided\")\n\n        # Initialize renderer and tokenizer\n        tokenizer = get_tokenizer(model_name)\n        self.renderer = renderers.get_renderer(name=renderer_name, tokenizer=tokenizer)\n        self.verbose = verbose\n        self.include_reasoning = include_reasoning\n\n    async def generate(\n        self,\n        input: list[InspectAIChatMessage],\n        tools: list[InspectAIToolInfo],\n        tool_choice: InspectAIToolChoice,\n        config: InspectAIGenerateConfig,\n    ) -> InspectAIModelOutput:\n        \"\"\"\n        The main interface that needs to be implemented to test a new model.\n        \"\"\"\n        if config.system_message:\n            input = [ChatMessageSystem(content=config.system_message)] + input\n        convo = convert_inspect_messages(input)\n        prompt = self.renderer.build_generation_prompt(convo)\n        num_responses = 1 if config.num_choices is None else config.num_choices\n        sampling_params = tinker.SamplingParams(\n            temperature=config.temperature if config.temperature is not None else 1.0,\n            max_tokens=config.max_tokens or 128,\n            stop=self.renderer.get_stop_sequences(),\n            top_p=config.top_p if config.top_p is not None else 1.0,\n            top_k=config.top_k if config.top_k is not None else -1,\n            seed=config.seed,\n        )\n\n      
  start_time = time.time()\n        sample_result = await self.sampling_client.sample_async(\n            prompt=prompt, sampling_params=sampling_params, num_samples=num_responses\n        )\n        sampled_token_sequences = sample_result.sequences\n\n        # Optional verbose output (only for standalone use)\n        if self.verbose:\n            prompt_text = colored(self.renderer.tokenizer.decode(prompt.to_ints()), \"green\")\n            logger.info(f\"[Prompt]\\n{prompt_text}\")\n            for i, seq in enumerate(sampled_token_sequences):\n                response_text = colored(self.renderer.tokenizer.decode(seq.tokens), \"red\")\n                logger.info(f\"[Response {i + 1}/{num_responses}]\\n{response_text}\")\n\n        end_time = time.time()\n\n        parsed_responses = [\n            self.renderer.parse_response(r.tokens)[0] for r in sampled_token_sequences\n        ]\n        if self.include_reasoning:\n            all_choices = [\n                InspectAIModelOutputChoice(\n                    message=InspectAIChatMessageAssistant(\n                        content=_message_to_inspect_content(r), model=self.model_name\n                    ),\n                    stop_reason=\"stop\",\n                )\n                for r in parsed_responses\n            ]\n        else:\n            all_choices = [\n                InspectAIModelOutputChoice(\n                    message=InspectAIChatMessageAssistant(\n                        content=renderers.get_text_content(r), model=self.model_name\n                    ),\n                    stop_reason=\"stop\",\n                )\n                for r in parsed_responses\n            ]\n        usage = get_model_usage(prompt.to_ints(), sampled_token_sequences)\n\n        return InspectAIModelOutput(\n            model=self.model_name, choices=all_choices, time=end_time - start_time, usage=usage\n        )\n\n\n# Register with inspect_ai's model registry.\n# Using modelapi_register instead of @modelapi decorator preserves the __init__ signature for pyright.\nmodelapi_register(InspectAPIFromTinkerSampling, \"tinker-sampling\")\n"
  },
  {
    "path": "tinker_cookbook/eval/inspect_utils_test.py",
    "content": "\"\"\"Tests for inspect_utils conversion functions.\"\"\"\n\nimport pytest\n\npytest.importorskip(\"inspect_ai\")\n\nfrom inspect_ai.model import ChatMessage as InspectAIChatMessage\nfrom inspect_ai.model import ChatMessageAssistant as InspectAIChatMessageAssistant\nfrom inspect_ai.model import ChatMessageUser as InspectAIChatMessageUser\nfrom inspect_ai.model import ContentReasoning as InspectAIContentReasoning\nfrom inspect_ai.model import ContentText as InspectAIContentText\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.eval.inspect_utils import _message_to_inspect_content, convert_inspect_messages\n\n# --- Output: _message_to_inspect_content ---\n\n\ndef test_message_to_inspect_content_with_thinking():\n    message = renderers.Message(\n        role=\"assistant\",\n        content=[\n            renderers.ThinkingPart(type=\"thinking\", thinking=\"let me think\"),\n            renderers.TextPart(type=\"text\", text=\"the answer\"),\n        ],\n    )\n    result = _message_to_inspect_content(message)\n    assert len(result) == 2\n    assert isinstance(result[0], InspectAIContentReasoning)\n    assert result[0].reasoning == \"let me think\"\n    assert isinstance(result[1], InspectAIContentText)\n    assert result[1].text == \"the answer\"\n\n\ndef test_message_to_inspect_content_string_content():\n    message = renderers.Message(role=\"assistant\", content=\"plain answer\")\n    result = _message_to_inspect_content(message)\n    assert len(result) == 1\n    assert isinstance(result[0], InspectAIContentText)\n    assert result[0].text == \"plain answer\"\n\n\ndef test_message_to_inspect_content_text_only_parts():\n    message = renderers.Message(\n        role=\"assistant\",\n        content=[renderers.TextPart(type=\"text\", text=\"just text\")],\n    )\n    result = _message_to_inspect_content(message)\n    assert len(result) == 1\n    assert isinstance(result[0], InspectAIContentText)\n    assert result[0].text == \"just text\"\n\n\ndef test_message_to_inspect_content_empty_thinking():\n    message = renderers.Message(\n        role=\"assistant\",\n        content=[\n            renderers.ThinkingPart(type=\"thinking\", thinking=\"\"),\n            renderers.TextPart(type=\"text\", text=\"answer\"),\n        ],\n    )\n    result = _message_to_inspect_content(message)\n    assert len(result) == 2\n    assert isinstance(result[0], InspectAIContentReasoning)\n    assert result[0].reasoning == \"\"\n\n\n# --- Input: convert_inspect_messages ---\n\n\ndef test_convert_inspect_messages_string_content():\n    messages: list[InspectAIChatMessage] = [\n        InspectAIChatMessageUser(content=\"hello\"),\n        InspectAIChatMessageAssistant(content=\"hi there\"),\n    ]\n    result = convert_inspect_messages(messages)\n    assert len(result) == 2\n    assert result[0][\"role\"] == \"user\"\n    assert result[0][\"content\"] == \"hello\"\n    assert result[1][\"role\"] == \"assistant\"\n    assert result[1][\"content\"] == \"hi there\"\n\n\ndef test_convert_inspect_messages_structured_assistant():\n    messages: list[InspectAIChatMessage] = [\n        InspectAIChatMessageAssistant(\n            content=[\n                InspectAIContentReasoning(reasoning=\"thinking...\"),\n                InspectAIContentText(text=\"answer\"),\n            ]\n        ),\n    ]\n    result = convert_inspect_messages(messages)\n    assert len(result) == 1\n    content = result[0][\"content\"]\n    assert isinstance(content, list)\n    assert len(content) == 2\n    assert 
content[0][\"type\"] == \"thinking\"\n    assert content[0][\"thinking\"] == \"thinking...\"  # type: ignore[typeddict-item]\n    assert content[1][\"type\"] == \"text\"\n    assert content[1][\"text\"] == \"answer\"  # type: ignore[typeddict-item]\n\n\ndef test_convert_inspect_messages_structured_non_assistant_flattens():\n    messages: list[InspectAIChatMessage] = [\n        InspectAIChatMessageUser(\n            content=[\n                InspectAIContentText(text=\"hello\"),\n                InspectAIContentText(text=\"world\"),\n            ]\n        ),\n    ]\n    result = convert_inspect_messages(messages)\n    assert len(result) == 1\n    assert result[0][\"role\"] == \"user\"\n    assert result[0][\"content\"] == \"hello world\"\n"
  },
  {
    "path": "tinker_cookbook/eval/run_inspect_evals.py",
    "content": "import asyncio\nimport logging\n\nimport chz\nimport tinker\n\nfrom tinker_cookbook import checkpoint_utils, model_info\nfrom tinker_cookbook.eval.inspect_evaluators import InspectEvaluator, InspectEvaluatorBuilder\nfrom tinker_cookbook.exceptions import ConfigurationError\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass Config(InspectEvaluatorBuilder):\n    model_path: str | None = None\n\n\nasync def main(config: Config):\n    logging.basicConfig(level=logging.INFO)\n\n    # Create a sampling client from the model path\n    service_client = tinker.ServiceClient()\n    model_path = config.model_path\n    model_name = config.model_name\n    renderer_name = config.renderer_name\n\n    # Resolve model name from checkpoint when needed, and validate explicit model_name if provided.\n    if model_path is not None:\n        rest_client = service_client.create_rest_client()\n        training_run = await rest_client.get_training_run_by_tinker_path_async(model_path)\n        if model_name is not None and model_name != training_run.base_model:\n            raise ConfigurationError(\n                f\"Model name {model_name} does not match training run base model {training_run.base_model}\"\n            )\n        model_name = model_name or training_run.base_model\n\n    if model_name is None:\n        raise ConfigurationError(\"model_path or model_name must be provided\")\n\n    # Resolve renderer with precedence: explicit config > checkpoint metadata > model default.\n    if renderer_name is None and model_path is not None:\n        renderer_name = await checkpoint_utils.get_renderer_name_from_checkpoint_async(\n            service_client, model_path\n        )\n    if renderer_name is None:\n        renderer_name = model_info.get_recommended_renderer_name(model_name)\n\n    config = chz.replace(config, model_name=model_name, renderer_name=renderer_name)\n\n    logger.info(f\"Using base model: {config.model_name}\")\n    logger.info(f\"Using renderer: {config.renderer_name}\")\n\n    sampling_client = service_client.create_sampling_client(\n        model_path=config.model_path, base_model=config.model_name\n    )\n\n    # Run the evaluation\n    logger.info(f\"Running inspect evaluation for tasks: {config.tasks}\")\n\n    # Create the inspect evaluator\n    evaluator = InspectEvaluator(config)\n    metrics = await evaluator(sampling_client)\n\n    # Print results\n    logger.info(\"Inspect evaluation completed!\")\n    logger.info(\"Results:\")\n    for metric_name, metric_value in metrics.items():\n        logger.info(f\"  {metric_name}: {metric_value}\")\n\n\nif __name__ == \"__main__\":\n    asyncio.run(chz.nested_entrypoint(main))\n"
  },
  {
    "path": "tinker_cookbook/example_data/conversations.jsonl",
    "content": "{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of France?\"}, {\"role\": \"assistant\", \"content\": \"The capital of France is Paris.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Explain photosynthesis in simple terms.\"}, {\"role\": \"assistant\", \"content\": \"Photosynthesis is the process plants use to make their own food. They take in sunlight, water, and carbon dioxide, and convert them into sugar (glucose) and oxygen. It's like the plant's way of cooking its own meal using sunlight as the energy source!\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful math tutor.\"}, {\"role\": \"user\", \"content\": \"What is 15 + 27?\"}, {\"role\": \"assistant\", \"content\": \"15 + 27 = 42\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Write a haiku about programming.\"}, {\"role\": \"assistant\", \"content\": \"Code flows like water\\nBugs hide in syntax shadows\\nDebugger finds peace\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who wrote 'Pride and Prejudice'?\"}, {\"role\": \"assistant\", \"content\": \"'Pride and Prejudice' was written by Jane Austen.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the boiling point of water at sea level?\"}, {\"role\": \"assistant\", \"content\": \"The boiling point of water at sea level is 100°C or 212°F.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a travel guide.\"}, {\"role\": \"user\", \"content\": \"What are the top attractions in Rome?\"}, {\"role\": \"assistant\", \"content\": \"Some top attractions in Rome are the Colosseum, Roman Forum, Pantheon, Trevi Fountain, and Vatican City.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Define gravity.\"}, {\"role\": \"assistant\", \"content\": \"Gravity is the force that attracts two bodies toward each other, most notably the attraction between the Earth and objects on it.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the square root of 81?\"}, {\"role\": \"assistant\", \"content\": \"The square root of 81 is 9.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful science explainer.\"}, {\"role\": \"user\", \"content\": \"Why is the sky blue?\"}, {\"role\": \"assistant\", \"content\": \"The sky appears blue because of Rayleigh scattering. Molecules in the atmosphere scatter shorter wavelengths of light (blue) more than longer wavelengths (red).\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Tell me a joke about computers.\"}, {\"role\": \"assistant\", \"content\": \"Why did the computer go to the doctor? Because it caught a virus!\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who was the first President of the United States?\"}, {\"role\": \"assistant\", \"content\": \"George Washington was the first President of the United States.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Explain what an atom is.\"}, {\"role\": \"assistant\", \"content\": \"An atom is the smallest unit of matter that retains the properties of an element. 
It is made up of protons, neutrons, and electrons.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Japan?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Japan is Tokyo.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 12 x 8?\"}, {\"role\": \"assistant\", \"content\": \"12 multiplied by 8 equals 96.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a poet.\"}, {\"role\": \"user\", \"content\": \"Write a short poem about the moon.\"}, {\"role\": \"assistant\", \"content\": \"The moonlight glows bright\\nGuiding travelers at night\\nSilent, silver friend\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who painted the Mona Lisa?\"}, {\"role\": \"assistant\", \"content\": \"The Mona Lisa was painted by Leonardo da Vinci.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the chemical symbol for gold?\"}, {\"role\": \"assistant\", \"content\": \"The chemical symbol for gold is Au.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"How many continents are there?\"}, {\"role\": \"assistant\", \"content\": \"There are seven continents: Africa, Antarctica, Asia, Australia, Europe, North America, and South America.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the largest mammal?\"}, {\"role\": \"assistant\", \"content\": \"The largest mammal is the blue whale.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a history expert.\"}, {\"role\": \"user\", \"content\": \"What caused World War I?\"}, {\"role\": \"assistant\", \"content\": \"World War I was triggered by the assassination of Archduke Franz Ferdinand in 1914, but underlying causes included militarism, alliances, imperialism, and nationalism.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Translate 'hello' to Spanish.\"}, {\"role\": \"assistant\", \"content\": \"'Hello' in Spanish is 'Hola'.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the tallest mountain in the world?\"}, {\"role\": \"assistant\", \"content\": \"Mount Everest is the tallest mountain in the world at 8,849 meters (29,032 feet).\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who developed the theory of relativity?\"}, {\"role\": \"assistant\", \"content\": \"The theory of relativity was developed by Albert Einstein.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What planet is known as the Red Planet?\"}, {\"role\": \"assistant\", \"content\": \"Mars is known as the Red Planet.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"How many days are in a leap year?\"}, {\"role\": \"assistant\", \"content\": \"A leap year has 366 days.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful dictionary.\"}, {\"role\": \"user\", \"content\": \"Define 'optimistic'.\"}, {\"role\": \"assistant\", \"content\": \"'Optimistic' means having a hopeful and positive outlook on the future.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 9 squared?\"}, {\"role\": \"assistant\", \"content\": \"9 squared is 81.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Where is the Great Barrier Reef located?\"}, {\"role\": \"assistant\", \"content\": \"The Great Barrier Reef is located off the coast of Queensland, Australia.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What gas do humans need to breathe?\"}, {\"role\": \"assistant\", \"content\": \"Humans need oxygen to breathe.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": 
\"You are a fun fact generator.\"}, {\"role\": \"user\", \"content\": \"Give me a fun fact.\"}, {\"role\": \"assistant\", \"content\": \"Octopuses have three hearts, and two of them stop beating when they swim.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Canada?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Canada is Ottawa.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who invented the light bulb?\"}, {\"role\": \"assistant\", \"content\": \"Thomas Edison is often credited with inventing the light bulb, though many inventors contributed to its development.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the freezing point of water in Celsius?\"}, {\"role\": \"assistant\", \"content\": \"The freezing point of water in Celsius is 0°C.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who discovered penicillin?\"}, {\"role\": \"assistant\", \"content\": \"Penicillin was discovered by Alexander Fleming in 1928.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a sports commentator.\"}, {\"role\": \"user\", \"content\": \"Who won the FIFA World Cup in 2018?\"}, {\"role\": \"assistant\", \"content\": \"France won the FIFA World Cup in 2018.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 7 factorial?\"}, {\"role\": \"assistant\", \"content\": \"7 factorial (7!) = 7 × 6 × 5 × 4 × 3 × 2 × 1 = 5040.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Name the largest desert in the world.\"}, {\"role\": \"assistant\", \"content\": \"The largest desert in the world is the Antarctic Desert.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"How many players are on a soccer team?\"}, {\"role\": \"assistant\", \"content\": \"A soccer team has 11 players on the field.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What does DNA stand for?\"}, {\"role\": \"assistant\", \"content\": \"DNA stands for Deoxyribonucleic Acid.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a music teacher.\"}, {\"role\": \"user\", \"content\": \"What are the notes in a C major scale?\"}, {\"role\": \"assistant\", \"content\": \"The notes in a C major scale are C, D, E, F, G, A, B, and back to C.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Brazil?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Brazil is Brasília.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 45 divided by 5?\"}, {\"role\": \"assistant\", \"content\": \"45 divided by 5 is 9.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who was Cleopatra?\"}, {\"role\": \"assistant\", \"content\": \"Cleopatra was the last active ruler of the Ptolemaic Kingdom of Egypt.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the speed of light?\"}, {\"role\": \"assistant\", \"content\": \"The speed of light in a vacuum is approximately 299,792 kilometers per second (186,282 miles per second).\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a cooking assistant.\"}, {\"role\": \"user\", \"content\": \"How do you boil an egg?\"}, {\"role\": \"assistant\", \"content\": \"Place eggs in boiling water for about 9–12 minutes depending on desired firmness, then cool in ice water.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Germany?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Germany is Berlin.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 
1000 divided by 25?\"}, {\"role\": \"assistant\", \"content\": \"1000 divided by 25 is 40.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the chemical symbol for silver?\"}, {\"role\": \"assistant\", \"content\": \"The chemical symbol for silver is Ag.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the tallest building in the world?\"}, {\"role\": \"assistant\", \"content\": \"The tallest building in the world is the Burj Khalifa in Dubai, standing at 828 meters (2,717 feet).\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who wrote 'Hamlet'?\"}, {\"role\": \"assistant\", \"content\": \"'Hamlet' was written by William Shakespeare.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 64 divided by 8?\"}, {\"role\": \"assistant\", \"content\": \"64 divided by 8 is 8.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a weather expert.\"}, {\"role\": \"user\", \"content\": \"What causes thunderstorms?\"}, {\"role\": \"assistant\", \"content\": \"Thunderstorms are caused by rapidly rising warm, moist air that cools and condenses, forming cumulonimbus clouds with lightning and thunder.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the largest ocean?\"}, {\"role\": \"assistant\", \"content\": \"The Pacific Ocean is the largest ocean on Earth.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Name the process by which plants lose water vapor.\"}, {\"role\": \"assistant\", \"content\": \"The process is called transpiration.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 11 times 11?\"}, {\"role\": \"assistant\", \"content\": \"11 times 11 equals 121.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who painted 'Starry Night'?\"}, {\"role\": \"assistant\", \"content\": \"'Starry Night' was painted by Vincent van Gogh.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a riddle master.\"}, {\"role\": \"user\", \"content\": \"Give me a short riddle.\"}, {\"role\": \"assistant\", \"content\": \"What has keys but can't open locks? 
— A piano.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the main gas in the Earth's atmosphere?\"}, {\"role\": \"assistant\", \"content\": \"The main gas in Earth's atmosphere is nitrogen, making up about 78%.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Italy?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Italy is Rome.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who is known as the Father of Computers?\"}, {\"role\": \"assistant\", \"content\": \"Charles Babbage is known as the Father of Computers.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the smallest prime number?\"}, {\"role\": \"assistant\", \"content\": \"The smallest prime number is 2.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a nutrition coach.\"}, {\"role\": \"user\", \"content\": \"What vitamin do we get from sunlight?\"}, {\"role\": \"assistant\", \"content\": \"We get vitamin D from sunlight.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 3 cubed?\"}, {\"role\": \"assistant\", \"content\": \"3 cubed (3³) is 27.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Where is Mount Kilimanjaro located?\"}, {\"role\": \"assistant\", \"content\": \"Mount Kilimanjaro is located in Tanzania.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who was the first man on the moon?\"}, {\"role\": \"assistant\", \"content\": \"Neil Armstrong was the first man to walk on the moon in 1969.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the largest planet in our solar system?\"}, {\"role\": \"assistant\", \"content\": \"The largest planet in our solar system is Jupiter.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a movie buff.\"}, {\"role\": \"user\", \"content\": \"Who directed 'Jurassic Park'?\"}, {\"role\": \"assistant\", \"content\": \"'Jurassic Park' was directed by Steven Spielberg.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 200 minus 75?\"}, {\"role\": \"assistant\", \"content\": \"200 minus 75 is 125.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What language has the most native speakers?\"}, {\"role\": \"assistant\", \"content\": \"Mandarin Chinese has the most native speakers.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"How many bones are in the adult human body?\"}, {\"role\": \"assistant\", \"content\": \"The adult human body has 206 bones.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Egypt?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Egypt is Cairo.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the first element on the periodic table?\"}, {\"role\": \"assistant\", \"content\": \"The first element on the periodic table is Hydrogen (H).\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a storyteller.\"}, {\"role\": \"user\", \"content\": \"Tell me a one-sentence story.\"}, {\"role\": \"assistant\", \"content\": \"A lost balloon drifted into the sky, carrying with it the dreams of a small child.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the currency of the United Kingdom?\"}, {\"role\": \"assistant\", \"content\": \"The currency of the United Kingdom is the British Pound Sterling (GBP).\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the hardest natural substance on Earth?\"}, {\"role\": \"assistant\", \"content\": \"The hardest natural substance on 
Earth is diamond.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 14 multiplied by 6?\"}, {\"role\": \"assistant\", \"content\": \"14 multiplied by 6 is 84.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a computer science tutor.\"}, {\"role\": \"user\", \"content\": \"What does HTML stand for?\"}, {\"role\": \"assistant\", \"content\": \"HTML stands for HyperText Markup Language.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Which planet is closest to the Sun?\"}, {\"role\": \"assistant\", \"content\": \"The planet closest to the Sun is Mercury.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who was Albert Einstein?\"}, {\"role\": \"assistant\", \"content\": \"Albert Einstein was a theoretical physicist best known for developing the theory of relativity.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Australia?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Australia is Canberra.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful editor.\"}, {\"role\": \"user\", \"content\": \"Correct this sentence: 'She go to school everyday.'\"}, {\"role\": \"assistant\", \"content\": \"The correct sentence is: 'She goes to school every day.'\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the fastest land animal?\"}, {\"role\": \"assistant\", \"content\": \"The fastest land animal is the cheetah.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who discovered gravity?\"}, {\"role\": \"assistant\", \"content\": \"Sir Isaac Newton is credited with discovering the law of gravity.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 90 divided by 9?\"}, {\"role\": \"assistant\", \"content\": \"90 divided by 9 is 10.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Spain?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Spain is Madrid.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a myth buster.\"}, {\"role\": \"user\", \"content\": \"Is lightning hotter than the Sun?\"}, {\"role\": \"assistant\", \"content\": \"Yes, lightning can be hotter than the surface of the Sun. 
A lightning bolt can reach around 30,000 Kelvin, while the Sun’s surface is about 5,500 Kelvin.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the main ingredient in guacamole?\"}, {\"role\": \"assistant\", \"content\": \"The main ingredient in guacamole is avocado.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What year did World War II end?\"}, {\"role\": \"assistant\", \"content\": \"World War II ended in 1945.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the largest internal organ in the human body?\"}, {\"role\": \"assistant\", \"content\": \"The largest internal organ in the human body is the liver.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 2 to the power of 5?\"}, {\"role\": \"assistant\", \"content\": \"2 to the power of 5 (2⁵) is 32.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a language teacher.\"}, {\"role\": \"user\", \"content\": \"Translate 'Good morning' to French.\"}, {\"role\": \"assistant\", \"content\": \"'Good morning' in French is 'Bonjour'.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who was the first female Prime Minister of the UK?\"}, {\"role\": \"assistant\", \"content\": \"Margaret Thatcher was the first female Prime Minister of the United Kingdom.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Russia?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Russia is Moscow.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"How many stripes are on the US flag?\"}, {\"role\": \"assistant\", \"content\": \"The US flag has 13 stripes, representing the original 13 colonies.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is photosynthesis?\"}, {\"role\": \"assistant\", \"content\": \"Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to make food and release oxygen.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who painted the ceiling of the Sistine Chapel?\"}, {\"role\": \"assistant\", \"content\": \"Michelangelo painted the ceiling of the Sistine Chapel.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of China?\"}, {\"role\": \"assistant\", \"content\": \"The capital of China is Beijing.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a puzzle master.\"}, {\"role\": \"user\", \"content\": \"Give me a quick puzzle.\"}, {\"role\": \"assistant\", \"content\": \"I speak without a mouth and hear without ears. I have no body, but I come alive with the wind. What am I? 
— An echo.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Mexico?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Mexico is Mexico City.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 5 times 16?\"}, {\"role\": \"assistant\", \"content\": \"5 times 16 is 80.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a space expert.\"}, {\"role\": \"user\", \"content\": \"What is a black hole?\"}, {\"role\": \"assistant\", \"content\": \"A black hole is a region of space where gravity is so strong that not even light can escape from it.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who wrote 'The Odyssey'?\"}, {\"role\": \"assistant\", \"content\": \"'The Odyssey' is attributed to the ancient Greek poet Homer.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the largest island in the world?\"}, {\"role\": \"assistant\", \"content\": \"The largest island in the world is Greenland.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 1 kilometer in meters?\"}, {\"role\": \"assistant\", \"content\": \"1 kilometer equals 1,000 meters.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What organ pumps blood through the body?\"}, {\"role\": \"assistant\", \"content\": \"The heart pumps blood through the body.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a geography quiz master.\"}, {\"role\": \"user\", \"content\": \"What is the capital of Argentina?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Argentina is Buenos Aires.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"How many colors are in a rainbow?\"}, {\"role\": \"assistant\", \"content\": \"There are seven colors in a rainbow: red, orange, yellow, green, blue, indigo, and violet.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What year did the Titanic sink?\"}, {\"role\": \"assistant\", \"content\": \"The Titanic sank in 1912.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the chemical formula for water?\"}, {\"role\": \"assistant\", \"content\": \"The chemical formula for water is H₂O.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 144 divided by 12?\"}, {\"role\": \"assistant\", \"content\": \"144 divided by 12 is 12.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a biology tutor.\"}, {\"role\": \"user\", \"content\": \"What is the function of red blood cells?\"}, {\"role\": \"assistant\", \"content\": \"Red blood cells carry oxygen from the lungs to the body and return carbon dioxide to the lungs to be exhaled.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who was the first woman in space?\"}, {\"role\": \"assistant\", \"content\": \"The first woman in space was Valentina Tereshkova from the Soviet Union in 1963.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of South Korea?\"}, {\"role\": \"assistant\", \"content\": \"The capital of South Korea is Seoul.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the square root of 144?\"}, {\"role\": \"assistant\", \"content\": \"The square root of 144 is 12.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the largest bird in the world?\"}, {\"role\": \"assistant\", \"content\": \"The largest bird in the world is the ostrich.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a literature assistant.\"}, {\"role\": \"user\", \"content\": \"Who wrote 'To Kill a 
Mockingbird'?\"}, {\"role\": \"assistant\", \"content\": \"'To Kill a Mockingbird' was written by Harper Lee.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 250 minus 125?\"}, {\"role\": \"assistant\", \"content\": \"250 minus 125 is 125.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Which blood type is known as the universal donor?\"}, {\"role\": \"assistant\", \"content\": \"Blood type O negative is known as the universal donor.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Greece?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Greece is Athens.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the chemical symbol for iron?\"}, {\"role\": \"assistant\", \"content\": \"The chemical symbol for iron is Fe.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"How many planets are in the solar system?\"}, {\"role\": \"assistant\", \"content\": \"There are eight planets in the solar system.\"}]}\n{\"messages\": [{\"role\": \"system\", \"content\": \"You are a joke teller.\"}, {\"role\": \"user\", \"content\": \"Tell me a knock-knock joke.\"}, {\"role\": \"assistant\", \"content\": \"Knock knock.\\nWho's there?\\nLettuce.\\nLettuce who?\\nLettuce in, it's cold out here!\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Turkey?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Turkey is Ankara.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is 8 times 15?\"}, {\"role\": \"assistant\", \"content\": \"8 times 15 is 120.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"Who developed the polio vaccine?\"}, {\"role\": \"assistant\", \"content\": \"The polio vaccine was developed by Jonas Salk.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the capital of Kenya?\"}, {\"role\": \"assistant\", \"content\": \"The capital of Kenya is Nairobi.\"}]}\n{\"messages\": [{\"role\": \"user\", \"content\": \"What is the smallest country in the world?\"}, {\"role\": \"assistant\", \"content\": \"The smallest country in the world is Vatican City.\"}]}\n"
  },
  {
    "path": "tinker_cookbook/example_data/multilingual.txt",
    "content": "وقال، ماما، لقد عدت للمنزل.\nИ той каза: Мамо, у дома съм.\nund er hat gesagt, Mama ich bin daheim.\nΚαι είπε, Μαμά, έφτασα στο σπίτι.\nAnd he said, Mama, I'm home.\nY él dijo: Mamá, estoy en casa.\nEt il a dit, maman, je suis à la maison.\nऔर उसने कहा, माँ, मैं घर आया हूं।\nИ он сказал: Мама, я дома.\nNaye akasema, Mama, niko nyumbani.\nและเขาพูดว่า, ม่าม๊า ผมอยู่บ้าน\nVe Anne, evdeyim dedi.\nاور اس نے کہا امّی، میں گھر آگیا ہوں۔\nVà anh ấy nói, Mẹ, con đã về nhà.\n他说，妈妈，我回来了。\nحسنا ، لم أكن أفكر حتى حول ذلك ، لكن كنت محبطاً تماما ،وأنهيت الحديث معه مرة ثانية .\nЕ, аз дори не мислех за това, но бях толкова разочарована, а в крайна сметка отново разговарях с него.\nNun, daran dachte ich nicht einmal, aber ich war so frustriert, dass ich am Ende doch mit ihm redete.\nΛοιπόν, δεν το σκέφτηκα καν, αλλά ήμουν τόσο απογοητευμένος, και κατέληξα να του μιλάω και πάλι.\nWell, I wasn't even thinking about that, but I was so frustrated, and, I ended up talking to him again.\nBien, ni estaba pensando en eso, pero estaba tan frustrada y empecé a hablar con él de nuevo.\nEh bien, je ne pensais même pas à cela, mais j'étais si frustré, et j'ai fini par lui reparler.\nहालांकि मैं इसके बारे में सोच भी नहीं रहा था लेकिन मैं इतना परेशान था कि मुझे वापस उससे बात करनी ही पड़ेगी\nНу, я даже не думал об этом, но я был так разочарован, что в конце концов опять поговорил с ним.\nNaam, sikukuwa nafikiri juu ya hilo, lakini nilichanganyikiwa sana, na, hatimaye nikaendelea kuzungumza naye tena.\nดี, ฉันไม่ได้คิดอะไรเกี่ยวกับเรื่องนี้, แต่ฉันก็ผิดหวัง, และ, ฉันก็กลับไปคุยกับเขาอีกครั้ง\nPekala, bunu hiç düşünmemiştim ancak kafam çok karıştı ve onunla tekrar konuşmadım.\nمیں اس کے بارے میں نہیں سوچھ رہی تھی ، لیکن میں اتنی مایوس تھی کہ اس سے دوبارہ بات کرنے لگی۔\nVâng, tôi thậm chí không nghĩ về điều đó, nhưng tôi đã rất thất vọng, và, tôi lại nói chuyện với anh ta lần nữa.\n嗯，我根本没想过，但是我很沮丧，最后我又和他说话了。\nواعتقدت أن ذلك شرف لي ، ولا يزال ، ولايزال ، كنت الوحيد برقم تسعة اثنان اثنان إي أكس أو والذي كان مجال مهنتي في سلاح الجو .\nИ аз мислех, че това е привилегия, и тя все още е, аз бях единственият бивш на 922, което беше моето поприще във въздушните сили AFFC.\n„Und ich dachte, das wäre ein Privileg, und das ist es immer noch, das ist esi mmer noch. 
Ich war der einzige 229 Ex-O, was mein AFCF Air Force-Karriere-Feld war.“\nΚαι σκέφτηκα ότι ήταν ένα προνόμιο, και είναι ακόμα, ήμουν ο μόνος 922 Ex-O που ήταν το πεδίο μου για Καριέρα στην Αεροπορία.\nAnd I thought that was a privilege, and it's still, it's still, I was the only nine two-two Ex-O which was my AFFC Air Force Career field.\nY pensé que era un privilegio, y todavía es, todavía es, yo era el único 922 Ex-O que estaba en mi campo de carrera AFFC de la Fuerza Aérea.\nEt je pensais que c'était un privilège, et ça l'est toujours, toujours, j'étais le seul 922 Ex-O, c'est là que j'avais effectué ma carrière dans l'AFFC Air Force.\nऔर मैंने सोचा कि यह एक विशेषाधिकार था, और यह अभी भी है, यह अभी भी है, मैं केवल नौ दो-दो Ex-O जो मेरे AFFC एयर फोर्स कैरियर फील्ड थे।\nА я считал это привилегией, и всё-же,  и всё-же я был единственный оперативный офицер категории девять два два - это была моя профессиональная специальность при Командовании ВВС.\nNami nilifikiri kwamba ilikuwa ni fursa, na bado, bado, nilikuwa pekee wa tisa mbili na mbili Ex-O ambayo ilikuwa uwanja wangu wa AFFC Air Force Career.\nและฉันนึกว่านั้นเป็นสิทธิพิเศษและมันยังคงเป็น ฉันเป็นเก้า สอง-สอง เอ็กซ์-โอ ซึ่งเป็นสาขาอาชีพ AFFC Air Force ของฉัน\nVe bunun bir ayrıcalık olduğunu sanıyordum, ve hala, hala benim, AFFC Hava Kuvvetleri Kariyer alanım olan dokuz tane iki iki X-O'ydu.\nاور میں نے سوچا کہ یہ ایک امتیاز تھا، اور یہ ابھی بھی ہے، یہ ابھی بھی ہے، میں صرف نو دو دو تھا، جو میرا AFFC ایئر فورس کیریئر فیلڈ تھا.\nVà tôi nghĩ đó là một đặc ân, và nó vẫn còn, nó vẫn là, tôi là chín hai hai Ex-O là lĩnh vực nghề nghiệp không quân AFFC của tôi.\n而我当初认为这是一个特权，我现在仍然这样想，我是唯一的922 Ex-O，也是我的AFFC空军职业生涯。\nأخبروني ،إيه، أنه سيتم استدعائي من قبل شاب في النهاية سيقابلني .\nКазаха ми, че накрая ще ми се обади човек, за да се срещнем.\nSie sagten mir, dass ich am Ende von einem Kerl angerufen werden würde, um mich zu treffen.\nΜου είπαν ότι, θα έπρεπε να καλέσω έναν άντρα στο τέλος για να συναντηθούμε.\nThey told me that, uh, that I would be called in a guy at the end for me to meet.\nMe dijeron que, esto... 
que un tipo me llamaría al final para vernos.\nIls m'ont dit qu'à la fin, on m'amènerait un homme pour que je le rencontre.\nउन्होंने मुझसे कहा की अंत में एक व्यक्ति से मीटिंग करनी होगी ।\nОни сказали, что меня вызовут , чтобы встретиться с парнем в конце.\nWalinieleza ya kwamba mwishowe ningeitiwa jamaa fulani ambaye tungepatana naye.\nพวกเขาบอกฉันว่าเขาจะเรียกคน ๆ หนึ่งเข้ามาในตอนท้ายให้ฉันพบ\nSonunda tanışmam için bir adam tarafından çağrılmamı söylediler.\nانہوں نے مجھے بتایا تھا کے آخر میں مجھ سے ملنے کے لۂے ایک آدمی بلایا جاۂے گا\nHọ nói với tôi rằng, ừ, rằng tôi sẽ được gọi bởi một chàng trai vào cuối cùng để gặp nhau.\n他们告诉我，呃，我最后会被叫到一个人那里去见面。\nهناك الكثير تستطيع التحدث عنه  وأنا سوف أتاجاوز ذلك تماما .\nИма толкова много, което може да се разкаже за това, че просто ще го пропусна.\nEs gibt so viel was ich darüber erzählen könnte, ich überspringe das einfach.\nΥπάρχουν τόσα πολλά που θα μπορούσες να μιλήσεις γι 'αυτό απλά θα τα παραλείψω.\nThere's so much you could talk about on that I'll just skip that.\nHay tanto que se puede decir sobre eso, que sencillamente me voy a saltar eso.\nIl y a tellement de choses dont vous pourriez parler que je vais juste m'en passer.\nइतना है कि आप इसके बारे में बात कर सकते हैं कि मैं इसे छोड़ूँगा|\nОб этом можно так много говорить, что я опущу подробности.\nKuna mengi ambayo unaweza kuzungumzia kuhusu hilo lakini  nitaachana nayo tu.\nมันมีอีกมากที่คุณสามารถพูดคุยเกี่ยวกับสิ่งนั้น ฉันจะข้ามไปละกัน\nBu konu hakkında söyleyebileceğin çok şey var pas geçiyorum.\nبہت اتنا ہے کہ آپ اس کے بارے میں بات کر سکتے ہیں. میں صرف اس کو چھوڑ دونگا.\nCó rất nhiều điều bạn có thể nói mà tôi sẽ chỉ bỏ qua điều đó thôi.\n你可以讲的太多了，我就不提了。\nلم أعرف من أجل ماذا أنا ذاهب أو أي شىْ ، لذلك كان علي أن أبلغ عن مكان محددا في واشنطن.\nНе знаех за какво отивам и въобще нищо, но трябваше да се явя на определено място във Вашингтон.\nIch wusste nicht was ich vorhatte oder so, ich musste mich an einen bestimmten Ort in Washington melden.\nΔεν ήξερα που πήγαινα ή κάτι τέτοιο, έτσι έπρεπε να αναφέρω ένα καθορισμένο μέρος στην Ουάσινγκτον.\nI didn't know what I was going for or anything, so was to report to a designated place in Washington.\nNo sabía para qué iba ni nada, así que iba a informarmar a un lugar designado en Washington.\nJe ne savais pas dans quoi je me lançais, donc j'allais être rattaché à un lieu désigné à Washington.\nमुझे नहीं पता था कि मैं क्या कर रहा था या कुछ भी, इसलिए वाशिंगटन में किसी निर्दिष्ट स्थान पर रिपोर्ट करना था ।\nЯ не знал, что мне предстояло сделать и все такое, так что я должен был сообщить в указанное место в Вашингтоне.\nSikujua nini nilichoendea au kitu chochote, hivyo ilikuwa na ni ripoti mahali paliopangwa huko Washington.\nฉันไม่รู้ว่าฉันไปเพื่ออะไรหรือเพื่อสิ่งใด ดังนั้นแค่รายงานเกี่ยวกับสถานที่ที่ระบุในวอชิงตัน\nNe için gittiğimi falan bilmiyordum, Washington'da belirtilen bir yere rapor vermem gerekiyordu.\nمجھے نہیں معلوم تھا کہ میں کیا کرنے جا رہا تھا، تو واشنگٹن میں ایک نامزد جگہ میں اطلاح دینی تھی.\nTôi đã không biết mình đang hướng tới mục đích gì, ngay cả việc báo cáo cho một địa chỉ cho trước ở Washington.\n我不知道我要去干什么还是什么的，所以就去华盛顿指定的地方报到。\nكان ذلك الشيْ الأساسي الوحيد الذي أردنا أن ننقذه لأنه لم يكن هناك أي طريقة لإلقاء قنبلة هيدروجينية بقوة 20 ميغا 30 C 124 .\nТова беше основното нещо, което искахме да спасим, тъй като нямаше начин да хвърлим 20-мегатонна водородна бомба от 30, C124.\nEs war das Wichtigste was wir sichern wollten da es keine Möglichkeit gab eine 20 Megatonnen- H- Bombe ab zu werfen von einem 30, C124.\nΑυτό ήταν το 
κύριο πράγμα που θέλαμε να σώσουμε αφού δεν υπήρχε κανένας τρόπος να ρίξουμε μια 20 τόνων βόμβα υδρογόνου από ένα 30, ένα C124.\nThat was the primary thing we wanted to save since there wasn't any way to dump a 20-megaton H-bomb off a 30, a C124.\nEso es era lo más importante que queríamos salvar, ya que no había forma de lanzar  una bomba hidrógena desde unos 30, un C124.\nC'est la première chose que nous voulions sauver puisqu'il n'y avait aucun moyen de jeter une bombe H de 20 mégatonnes du haut d'un 30, un C124.\nयही प्राथमिक चीज थी जिसे हम बचाना चाहते थे क्युकी एक २०-मेगाटन हाइड्रोजन बम एक ३०  एक सी १२४ को फेकने का कोई और उपाय नहीं था |\nЭто было главное, что мы старались сохранить, так как возможности сбросить водородную бомбу в 20 мегатонн с 30-ти, то есть с С-124, не было.\nhicho ndio kilikuwa kitu cha msingi tulichotka kuokoa kwa sababu haingewezakana kutupa bomu la hidrojeni megatoni 20  kwenye C124.\nมันคือสิ่งแรกที่เราอยากจะเก็บไว้เพราะไม่มีวิธีที่จะทิ้งระเบิดไฮโดรเจน 20 ล้านตันลง 30, a C124 ได้\n20 megatonluk H-bombasını 30 tane C124'ün üzerinden  atmanın bir yolu olmadığından, o kurtarmak istediğimiz ilk şeydi.\nہم اس ہی کو بچانا چاھتے تھے کیونکہ 20 میگا ٹن بم کو پھینکنے کا کوۂی طریکہ نہیں تھا\nĐó là thứ đầu tiên chúng tôi muốn giữ lại vì không có cách nào để ném một quả bom nhiệt hạch nặng 20 triệu tấn xuống từ máy bay C124, 30 cả.\n这是我们首要需要解救的事情，因为没有任何方法可以将一枚20兆吨重的氢弹从一架30，一架C124上卸下。\nلم يكن عليه أن يذهب.\nТой не можеше да си тръгне.\nEr konnte nicht gehen.\nΔεν πήγε να πάει.\nHe didn't get to go.\nÉl no pudo ir.\nIl n'a pas pu y aller.\nउसे जाने नहीं मिला\nЕму не удалось уйти.\nHakuweza kuenda.\nเขาอดไป\nGitmesine gerek yoktu.\nوہ جا نہیں سکا ۔\nAnh ấy không cần phải bắt đầu.\n他没去到。\nوكنت بحالة حسنة، وكان ذلك!\nИ аз казах ОК, и това беше!\nUnd ich dachte OK und das war es dann!\nΚαι ήμουν εντάξει, και αυτό ήταν!\nAnd I was like OK, and that was it!\nY yo estaba bien, ¡y eso fue todo!\nEt j'étais comme OK, et c'était tout !\nऔर मै तो था की भई ठीक है , और बस वही था ।\nА я говорю, типа, «ок», ну и все.\nMimi nilikuwa nimeridhika nayo.\nและฉันก็แบบว่าตอบตกลงและมันก็เท่านั่น!\nBenim için sorun yoktu, ve hepsi bu kadardı!\nاور میں یہ سوچ رہا تھا کہ چلو ٹھیک ہے اور بس یہی تھا۔\nVà tôi giống như ỔN, và đó là thế!\n我没事，就这样！\nلذلك أنا فعلا غير متأكد لماذا.\nТака че не съм много сигурен защо.\nIch bin mir also nicht wirklich sicher warum.\nΈτσι δεν είμαι σίγουρος γιατί.\nSo I'm not really sure why.\nAsí que no tengo muy claro por qué.\nDonc je ne sais pas vraiment pourquoi.\nइसलिए मुझे वाकई यकीन नहीं है कि क्यों ।\nИ я не самом деле не знаю, почему.\nHivyo, sijui kwa kweli nu kwa nini.\nฉันก็เลยไม่แน่ใจจริงๆ ว่าทำไม\nNeden olduğunu gerçekten bilmiyorum.\nتو میں بالکل پر یقین نہیں ہوں کہ کیوں۔\nVì vậy tôi không chắc tại sao nữa.\n我不知道为什么。\nلقد كنت هناك للتو محاولاً أن أعرف الأمر.\nАз просто бях там и се опитвах да разбера.\nIch war nur da und habe versucht, es zu verstehen.\nΉμουν απλά εκεί προσπαθώντας να καταλάβω.\nI was just there just trying to figure it out.\nSolo estaba allí tratando de resolverlo.\nJ'étais juste là juste à essayer de comprendre.\nमैं सिर्फ यह पता लगाने की कोशिश कर रहा था।\nЯ просто пытался понять, что происходит.\nNilikuwa tu uko kujaribu kujua ni lipi?\nฉันก็แค่อยู่ที่นั่นเพื่อลองคิดเกี่ยวกับมัน\nSadece bir yolunu bulmaya çalışıyordum.\nمیں وہاں صرف گتھی سلجھانے کی کوشش کر رہی تھی۔\nTôi chỉ ở nguyên đó, cố gắng hình dung ra vấn đề.\n我就在那里尝试解决这个问题。\nهذه هي فاني فلونو ، هي ترعرعت في أوغستا ، جورجيا ، وهي ستتحدث عن بعض قصص  طفولتها . 
،\nТова е Фани Флоно, а тя е израснала в Огъста, Джорджия, и ще разкаже някои истории от детството си.\nDas ist Fannie Flono, sie ist aufgewachsen in Ag--Augusta, GA, und wird über einige Erinnerungen aus ihrer Kindheit sprechen.\nΑυτή είναι η Fannie Flono, και μεγάλωσε στην Αγκούστα, στην Γεωργία, και πρόκειται να μιλήσει για μερικές ιστορίες από την παιδική της ηλικία.\nThis is Fannie Flono, and she grew up in Ag-- Augusta, GA, and she is going to talk about some stories from her childhood.\nEsta es Fannie Flono. Creció en Ag-- Augusta, GA, y va a hablar sobre algunas historias de su infancia.\nVoivi Fannie Flono, et elle a grandi à Ag- Augusta, GA, et elle va parler de quelques histoires de son enfance.\nये है फन्नी फ्लोनो, और ये पली बढ़ी हुई है आग- ऑगस्टा , जी ए में , और वह अपने बचपन के बारे में कुछ कहानिया बताने वाली है ।\nЭто Фанни Флоно, она выросла в Огасте, Джорджия, и расскажет несколько историй из своего детства.\nHuyu ni Fannie Flono, na yeye alikulia Ag - Augusta, GA, na yeye atakuja kuzungumza kuhusu baadhi ya hadithi kutoka utoto wake.\nนี่คือ Fannie Flono และเธอเติบโตขึ้นมาใน in Ag-- Augusta, GA และเธอกำลังจะพูดถึงเรื่องราวบางอย่างจากวัยเด็กของเธอ\nBu Fannie Flono, Augusta, GA da büyümüş ve çocukluğundan kalma birkaç hikaye anlatacak.\nیہ فینی فلونیو ہے,اور وہ  آگ- استانا میں بڑی ہوی،جارجیا,اور وہ اپنے بچپن سے کچھ کہانیوں کے بارے میں بات کرنے جا رہی ہے.\nĐây là Fannie Flono, và cô ấy đã lớn lên ở Ag - Augusta, GA, và cô ấy sẽ nói về một số câu chuyện từ thời thơ ấu của cô ấy.\n这是Fannie Flono，她在佐治亚州奥古斯塔长大，她会讲述她童年时的一些故事。\nو ، إيه، هم  فعلا توقفوا عن زيارة العائلة لأنهم قرروا تماماأنهم سيصبحوا بيضاً .\nИ, всъщност те спряха да посещават семейството, защото просто бяха решени, че ще бъдат бели.\nUnd, ähm, sie hörten tatsächlich auf, die Familie zu besuchen, weil sie gerade waren, nur bestimmt, dass sie weiß werden würden.\nΚαι, πράγματι, σταμάτησαν για επίσκεψη στην οικογένεια επειδή απλά είχαν αποφασίσει ότι θα ήταν λευκοί.\nAnd, uh, they kind of actually stopped visiting the family because they were just, just determined that they were going to be white.\nY, oye, de hecho dejaron de visitar a la familia porque estaban decididos que iban a ser blancos.\net, uh, ils ont, d'une certaine façon, arrêté en fait de visiter la famille parce qu'ils étaient juste, juste déterminés qu'ils allaient être blancs.\nऔर, उह, वे वास्तव में परिवार का दौरा बंद कर देते हैं क्योंकि वे केवल, केवल निर्धारित कर चुके हैं कि वे वाईट होने जा रहे हैं।\nИ... ээ... они как бы на самом деле перестали навещать родственников, потому что они были просто... 
просто уверены, что будут белыми.\nHatimaye walisita kuitembelea familia sababu walikuwa wameumua kwamba wanataka kuishi kama watu weupe.\nและ อ่า พวกเขาไม่ค่อยจะแวะมาเยี่ยมครอบครัวเพราะพวกเขาเพิ่งตัดสินใจว่าพวกเขาจะเป็นคนผิวขาว\nVe, ah, aslında aileyi ziyaret etmeyi bıraktılar gibi çünkü beyaz olmaya tamamen kararlıydılar.\nچال یہ ہے کہ مجھے بستی کا نیا سردار کم اور ان دائیوں میں سے ایک زیادہ سمجھا جائے جنہیں  وان ٹریپ کے بچوں نے ماریہ سے پہلے مار ڈالا تھا۔\nVà, ừ, họ thực sự đã dừng lại chuyến thăm gia đình bởi vì họ chỉ, chỉ cần xác định rằng họ sẽ trở nên trắng.\n他们停止了跟这家交朋友，因为他们决定了当白人。\nلقد وضعت خمسة فصائل من U2's\nА аз пуснах пет отряда от U2.\nUnd ich hatte äh, fünf Abteilungen von U2\nΚαι είχα βάλει, πέντε διμοιρίες από τα U2.\nAnd I had put uh, five detachments out of U2's.\nY había echado a cinco destacamentos de U2.\nEt j'avais mis euh, cinq détachements à l'extérieur de l'U2.\nऔर मैंने पांच टुकड़ीयाँ U2 से बाहर रखी था,।\nИ мне пришлось сделать 5 разъединений из U2.\nNa nilikuwa nimeweka, milango mitano nje ya U2's.\nและฉันได้ใส่ เอ่อ, ห้ากองออกมาจาก ยู2\nVe U2'lerden beş tane deşifre koymuştum.\nاور میں نے 5 میں سے دو علیحدہ کئے تھے۔\nVà tôi đã đặt uh, năm sự tách rời khỏi U2.\n我已经从U2乐队中去除了5个超然\nكنت الشخص الوحيد الذي ،إيه، يدير منظمات الاختبار في غرف الارتفاع الصغيرة.\nАз бях единственият, който някога е работил с регулаторите за теста в миниатюрните височинни камери.\nIch war der Einzige der jemals die Regler für den Test in den Miniaturhöhenkammern laufen ließ.\nΕγώ ήμουν ο μόνος που έτρεχε πάντα τις ρυθμίσεις για δοκιμή στους μικροσκοπικούς θαλάμους υψομέτρου.\nI was the only one that uh, ever run the regulators for the, the test in the miniature altitude chambers.\nYo era el único que alguna vez ejecutaba los reguladores para la prueba en las cámaras de altitud en miniatura.\nJ'étais le seul à, euh, toujours faire marcher les régulateurs pour le, le test dans les chambres d'altitude miniatures.\nवो सिर्फ मैं ही था जिन्होंने मिनियेचर आल्टिट्यूड चैम्बेर्स मे परीक्षा केलिए नियामकों को चलाया था।\nЯ был единственным, кто... э-э-э... 
вообще когда-либо регулировал приборы при проведении тестов в миниатюрных барокамерах.\nNi mimi pekee nilikua nayo uh, ushawahi endeleza dhibiti za majaribio katika chumba cha 'miniature altitude'.\nฉันเป็นเพียงคนเดียวที่เอ่อ เรียกใช้หน่วยควบคุมสำหรับการทดสอบในหอสูงขนาดเล็ก\nMinyatür irtifa odalarındaki test için regülatörleri çalıştıran tek kişi bendim.\nMai akela hi tha jisne kabhi chote altitude chambers test karne kai lye regulators ko run kiya tha.\nTôi là người duy nhất mà uh, từng điều hành các bộ điều chỉnh cho ờ, cho phần thử nghiệm trong các buồng áp lực cỡ nhỏ.\n我是唯一一个，呃，在微型的高度室里做测试一直运行调节器的人。\nوأنا ، رئيس رقباء ، متقاعد ، كما قال ريك.\nАз съм главен сержант, пенсиониран, както каза Рик.\nIch bin äh, Chief Master Sergeant, im Ruhestand, wie Rick sagte.\nΕγώ, κύριε αρχιλοχία, συνταξιοδοτήθηκα, όπως είπε ο Rick.\nI'm uh, Chief Master Sergeant, retired, as Rick said.\nSoy, esto, el sargento jefe maestro, jubilado, como dijo Rick.\nJe suis euh, grand sergent chef, à la retraite, comme Rick l'a dit.\nमै , चीफ मास्टर सार्जेंट हु, विरत, जैसे रिक ने कहा ।\nЯ, это, старший мастер-сержант в отставке, как Рик сказал.\nMstaafu sergenti mkuu, kama alivyosma Rick.\nผมมียศเป็นพันจ่าอากาศเอก ตอนนี้เกษียณแล้ว ตามที่ริคบอกนั่นแหละ\nRick’ in söylediği gibi ben Şef Master Sergeant emekli oldum.\nمیری جنس بہت دلچسپ ہے\nTôi vâng, Thượng Sỹ, đã nghỉ hưu, như Rick đã nói.\n“我是呃，首席军士长，退休了。”里克说。\nاعتادت كراني أن تحكي القصة عن اختها وزوج اختها كيف قررا الإنتقال إلى المدينة ، إلى أوغستا ـ وينتقلوا إلى البيض .\nА баба разказваше историята за това как сестра ѝ и съпругът на сестра ѝ решили, че ще се преместят в града, в Огъста, и ще минават за бели.\nUnd Oma erzählte uns immer die Geschichte in der Ihre Schwester und deren Mann entschieden das sie in die Stadt Augusta ziehen würden und als Weiß durchgehen würden.\nΚαι η Γιαγιά είπε την ιστορία για το πώς η αδελφή της και ο σύζυγος της αδερφής της αποφάσισαν ότι έπρεπε να μετακομίσουν στην πόλη, στην Αουγκούστα και να περάσουν για λευκούς.\nAnd Granny used to tell the story about how her sister and her sister's husband decided that they were going to move to the city, to Augusta, and pass for white.\nY la abuela solía contar la historia sobre cómo su hermana y el marido de esta decidieron que se iban a trasladar a la ciudad, a Augusta, y hacerse pasar por blancos.\nEt Grand-Mère racontait souvent l'histoire de sa sœur et de son beau-frère, qui avaient décidé de s'installer en ville, à Augusta, et de passer pour Blancs.\nऔर दादी वो कहानी सुनाती थि कि कैसे उनकी बहन और उनके पति ने अगस्ता जाने का और सफेद होनेका निर्णय किया\nИ бабушка рассказывала историю о том, как ее сестра и муж ее сестры решили, что они собираются переехать в город, к Августе, и лично считаться белыми\nNyanyangu alikuwa na mazoea ya kusema hadithi ya jinsi dadake na bwana wa dadake walivyoamua wataenda mjini Augusta na wjifanye wazungu.\nและย่าเคยเล่าเรื่องเกี่ยวที่น้องสาวของเธอและสามีของน้องสาวได้ตัดสินใจว่าพวกเขาจะย้ายไปที่เมืองออร์กัสต้าและมองว่าสีขาว\nVe Büyükanne kız kardeşi ve kız kardeşinin kocasının şehre, Agusta'ya taşınmaya ve beyaz olarak kabul edilmeye nasıl karar verdikleriyle ilgili hikayeyi anlatırdı.\nاور نانی نے اس قصہ کو بتانے کے لئے استعمال کیا تھا کہ اس کی بہن اور اس کی بہن کے شوہر نے فیصلہ کیا کہ وہ شہر میں جا رہے تھے،آسٹا کے لئے، اور سفید کے لئے منتقل.\nVà Bà thường kể câu chuyện về việc chị gái của bà và người chồng đã quyết định sẽ chuyển đến thành phố Augusta và chuyển sang màu trắng.\n老太太以前常说她姐姐和姐丈是如何决定要搬到奥古斯塔城里去，并且被当做白人看待。\nهناك بعض التوقعات للتدفق النقدي على مكتبي ، 
وأوه ، أه ، لمثل هذا كاتي ، هذا هو اسم العميل.\nИма някакви прогнози за паричните потоци на бюрото ми и, е, за някой си Cutty, това е името на клиента.\nEs gibt einige Cash-Flow-Projektionen auf meinem Schreibtisch und, ähm, es ist für so und so Cutty, das ist der Name des Kunden.\nΥπάρχουν μερικές προβλέψεις μετρητών στο γραφείο μου και είναι μια για μια Cutty, αυτό είναι το όνομα του πελάτη.\nThere's some cash flow projections on my desk and, um, uh, it's for such and such Cutty, that's the client's name.\nHay algunas proyecciones del flujo de caja y, uh, eh, es para tal y cual Cutty, el nombre del cliente.\nJ'ai des prévisions de flux de trésorerie sur mon bureau et, ah oui, c'est pour Cutty machin-chose, c'est ça le nom du client.\nमेरे मेंज पर कुछ पैसे की गति की रुपरेखा की प्रोजेक्शन है और वह ये और ये कुट्टी के लिए है, जो की क्लाइंट का नाम है ।\nУ меня на столе лежит некий прогноз движения денежных средств, и он предназначен для такой-то и такой-то Катти - это наименование клиента.\nKuna makadirio ya mtiririko wa fedha kwenye dawati yangu na, um, uh, ni kwa na kama vile Cutty, ndiyo jina la mteja.\nมันมีการประมาณกระแสเงินสดบนโต๊ะฉัน และสำหรับ Cutty นั่นเป็นชื่อของลูกค้า\nMasamda bazı nakit akışı projeksiyonları var ve bu da Cutty için, müşterinin adı.\nمیری میز پر کچھ کاش فلو کے تخمينيں ہيں اور وہ  کسی کٹی کے لۂے ہیں، وہ مؤکل کا نام ہے۔\nCó một số dự đoán về dòng tiền trên bàn làm việc của tôi và, ừm, nó dành cho một khách hàng tên Cutty.\n我的桌子上有一些现金流的推算，嗯，呃，是这个，这个Cutty，这是客户的名字。\nالفتاة التي يمكن أن تساعدني موجودة في جميع أنحاء المدينة.\nМомичето, което може да ми помогне, е на другия край на града.\nDas Mädchen, das mir helfen kann, ist auf der anderen Seite der Stadt.\nΤο κορίτσι που μπορεί να με βοηθήσει είναι στον δρόμο προς την πόλη.\nThe girl that can help me is all the way across town.\nLa chica que me puede ayudar está en la otra punta de la ciudad.\nLa fille qui peut m'aider est à l'autre bout de la ville.\nजो लड़की मेरी मदद कर सकती है वह पूरे शहर भर में है।\nДевушка, которая может мне помочь, находится на другой стороне города\nHuyu msichana anaweza kukusaidia kuenda popote utakapo mjini.\nผู้หญิงที่ช่วยฉันได้มีอยู่ทั่วเมือง\nBana yardım edebilecek olan kız şehrin diğer ucunda.\nجو لڑکی میری مدد کرسکتا ہے وہ شہر بھر میں ہے\nCô gái có thể giúp tôi khắp cả thị trấn.\n能帮助我的女孩在小镇的另一边。\nمايكل سانتو، من شركة فايروال وشركاه، في بوفالو، نيويورك، كانت تلك التي صنعتها، أه، آه، اخترعت منظم O2 العالي قبل أن يبنيوا النار على الموقد بشكل جيد.\nМайкъл Санто от, от Файъруел и Компания, от Бъфало, Ню Йорк, те бяха тези, които, ъъ, изобретиха високият О2 регулатор, преди това те построиха също и пожарния контрол за печки.\nMichael Santo von Firewell and Company aus Buffalo, New York, war derjenige, der äh, den Regler für hohe O2-Werte hergestellt, äh, erfunden hat, bevor die Feuerkontrolle auf dem Herd gut gebaut wurde.\nΟ Michael Santo, της Firewell and Company, από το Μπάφαλο, της Νέα Υόρκης, ήταν εκείνος που, κατασκεύασε, αχμ.. 
εφηύρε τον υψηλό ρυθμιστή Ο2 πριν αναπτύξουν τον έλεγχο της φωτιάς στη σόμπα.\nMichael Santo of, of Firewell and Company, of Buffalo, NY, they was the one that uh, manufactured, uh, invented the high O2 regulator before that they built fire control on the stove well.\nMichael Santo of, of Firewell and Company, of Buffalo, NY, they was the one that uh, manufactured, uh, invented the high O2 regulator before that they built fire control on the stove well.\nMichael Santo, de Firewell et Company, de Buffalo, New York, était celui qui, a fabriqué, inventé le régulateur à haute teneur en O2 avant de pouvoir bien contrôler le feu sur le poêle.\nबफलो, एन् वई, फयरवाल आँड कंपनी के मैकल सान्डो, वो है जो O2 नियामक बनाया या आविष्कार किया उसके पहले वो चूल्हे पर आग नियंत्रण बनाया ।\nМайкл Санто из компании Firewell and Company, это в Баффало, штат Нью-Йорк. Это они произвели, то есть, изобрели кислородный регулятор выского давления, а до этого сконструировали пожарный датчик для вытяжек.\nMichael Santo, wa Firewell na Kampuni, wa Buffalo, NY, ndiye ambaye alifanya, aliunda,  mdhibiti mkuu wa O2 kabla ya kujenga udhibiti wa moto kwenye jiko vizuri.\nMichael Santo of of Firewell and Company of Buffalo, NY พวกเขาเป็นผู้ผลิต ซึ่งเป็นผู้คิดค้นเครื่องควบคุม O2 ระดับสูงก่อนที่พวกเขาจะสร้างสิ่งควบคุมไฟบนเตา\nFirewell ve Company'den Michael Santo, NY'den Buffalo ocakta ateş kontrolünü yapmadan önce yüksek O2 regülatörünü icat ettiler.\nمائیکل سینٹو آف فائر ویل اور کمپنی کے بفیلو، وہ یہ تھے جہوں نے وہ، اوہ، تیار، اوہ نے اعلی O2 ریگولیٹر کا ایجاد کرنے سے پہلے کہ انہوں نے چولہے پر آگ قابو کرنے والا بنایا۔\nMichael Santo, của Firewell và Company, của Buffalo, NY, họ là một trong những nhà sản xuất, uh, đã phát minh ra bộ điều chỉnh O2 trước khi họ chế tạo điều khiển hỏa lực trên bếp lò.\n迈克尔·桑托来自纽约布法罗市消防公司，他们就是那个发明并生产了高氧气调节器的公司，在那之前他们在炉子上安装了防火装置。\nلكن كانوا منقسمين مثل من كان المزارعين ومن كان أطفال المنزل ، كان ذلك نوع من --\nНо те бяха разделени кой ще работи на полето и кои са децата в къщата, беше малко\nAber sie waren gespalten darüber, wer die Feldhände waren und wer die Hauskinder waren, es war irgendwie ...\nΑλλά χωρίστηκαν στο ποιοι ήταν τα βοηθητικά χέρια και ποιοι ήταν τα παιδιά του σπιτιού, ήταν κάτι τέτοιο -\nBut they were divided about like who were the field hands and who were the house kids, it was kind of--\nPero ellos estaban divididos acerca de quiénes eran las manos de campo y quiénes eran los niños de la casa, era algo así como...\nMais ils étaient divisés au sujet de qui étaient les gens de terrain et qui étaient les enfants de la maison, c'était plutôt ...\nपर वे अब भी आपस में बटे हुए थे की कौन है जो खेती संभालेगा और कौन है जो घर के बच्चे बनेंगे, वो कुछ अजीब सा था --\nНо у них были разногласия по поводу того, типа, кто будет работать в поле, а кто будет работать по хозяйству, и было как-то...\nwalipingana kuhusa ni akina nani walikuwa vijana wa mikono na ni akina nani walikuwa vijana wa kushinda nyumbani. 
Ilikuwa...\nแต่พวกมันถูกแบ่งประมาณว่าใครเป็นลูกจ้างซึ่งทำงานในท้องนาและใครเป็นคนทำงานบ้าน มันเป็นประเภทของ--\nAma onlar, tarla çocukları ve ev çocukları olarak bölündüler, bir çeşit ...\nلیکن وہ اس طرح تقسیم ہوئے تھے جیسے فیلڈ ہاتھ تھے اور جو گھر کے بچوں تھے،یہ قسم کی تھی --\nNhưng chúng tôi lúc đó chia thành hai nhóm kiểu như ai là kẻ làm lụng kiếm tiền và ai là kẻ ngồi nhà ăn bám, việc này cứ kiểu như--\n这些孩子被按农工，按大院孩子分类，就像是--。\nلكنه كان، كما تعلمون ، بالكثير من الطرق ، تماما مثل ابن مالك مزرعة لأنه كان ابن هذا الرجل الذي يملك الكثير من الممتلكات.\nНо той беше в много отношения като син на собственик на плантация, защото беше син на този човек, който притежаваше много имоти.\ner aber war, weißt du, in vielerlei Hinsicht, wie nur ein Sohn eines Plantagenbesitzers  weil er der Sohn von diesem Kerl war,der viel Grundbesitz besaß.\nΑλλά ήταν, ξέρεις, με πολλούς τρόπους, σαν τον γιο του ιδιοκτήτη της φυτείας επειδή ήταν γιος αυτού του άντρα και διέθετε μεγάλη περιουσία.\nBut he was, you know, in a lot of ways, just like a plantation owner's son because he was the son of this guy who owned a lot of property.\nPero, ya sabes, él era, en muchos sentidos, como el hijo del dueño de una plantación porque era el hijo de este hombre que tenía muchas propiedades.\nMais il était, vous le savez, à bien des égards, tout comme le fils d'un propriétaire de plantation, parce qu'il était le fils de ce type qui possédait beaucoup de biens.\nलेकिन वह था, आप कई तरह से जानते हैं, जैसे बागान मालिक के बेटे की तरह, क्योंकि वह उस व्यक्ति का बेटा था, जो बहुत अधिक सम्पत्ति का मालिक था।\nНо он был во многом, как-бы, всё равно что сын плантатора, так как являлся сыном человека, у которого было в собственности много чего.\nLakini alikuwa, unajua, kwa njia nyingi, kama mwana wa mmiliki wa mimea kwa sababu alikuwa mwana wa mtu huyu ambaye alikuwa na mali nyingi.\nแต่คุณก็รู้ เขาดูเหมือนกับลูกชายของเจ้าของสวนในหลาย ๆ ด้าน เพราะว่าเขาเป็นลูกของชายคนนี้ที่เป็นเจ้าของอสังหาริมทรัพย์มากมาย\nAma biliyorsun ki o pek çok bakımdan tarla sahibinin oğlu gibiydi çünkü o pek çok mülkün sahibi olan bu adamın oğluydu.\nلیکن وہ تھا، آپ جانتے تھے، بہت سے طریقوں میں، صرف ایک پودے کے مالک کا بیٹا کی طرح ہے کیونکہ وہ اس آدمی کا بیٹا تھا جو ملکیت کی ملکیت ہے.\nNhưng anh ta, bạn biết đấy, theo nhiều cách, giống như một người con trai của chủ đồn điền bởi vì anh ta là con trai của người đàn ông  sở hữu rất nhiều tài sản này.\n但他知道，在很多方面，就像种植园主的儿子一样，他的父亲有很多财产。\nسوف يتحدث معنا اليوم عن( اس اس) الثالثة ، )يو٢ ) كويك و بلاك بيرد الطائر الأسود)  .\nДнес той ще ни говори за Третия SS, U2 Quick и Blackbird.\nHeute spricht er zu uns über die Third SS, die U2 Quick und die Blackbird.\nΣήμερα θα μας μιλήσει για την Τρίτη SS, το U2 Quick και το Blackbird.\nToday he is going to talk to us about the Third SS, the U2 Quick and the Blackbird.\nHoy nos hablará sobre la Tercera SS, el U2 Quick y el Blackbird.\nAujourd'hui il va nous parler au sujet de troisième SS, le rapide U2 et l'Oiseau Noir.\nआज वह तीसरा एस एस, यू 2 त्वरित और ब्लैक के बारे में हमें बात करने के लिए जा रहा है।\nСегодня он расскажет нам о Третьей экскадрилье стратегической поддержки, самолетах U2 Quick и Blackbird.\nLeo atatuzungumzia kuhusu majeshi ya Third SS, u2 Quick na Blackbird.\nวันนี้เขาจะพูดคุยกับเราเกี่ยวกับ Third SS, U2 Quick และ Blackbird\nBugün bize, Üçüncü SS, U2 Quick ve Blackbird'den bahsedecek.\nآج وہ ہم سے                                             سوم SS,   U2 Quick اور Blackbird        کے بارے میں  بات کرنے والا ہے۔\nHôm nay anh ấy sẽ nói chuyện với chúng ta về Third SS, U2 
Quick và Blackbird.\n今天他要和我们谈谈Third SS，U2 Quick和Blackbird。\nأقصد  كان لديهم خمس أولاد فقط ، مات واحد منهم .\nИскам да кажа, че имаха само пет деца, едно от тях почена.\nIch meine sie hatten nur ungefähr 5 Kinder, eines starb.\nΘέλω να πω ότι είχαν μόνο πέντε παιδιά, ένα από αυτά πέθανε.\nI mean they only had, like, they had five children, one of them died.\nQuiero decir que ellos solo tenían, algo así como, cinco hijos. Uno de ellos murió.\nJe veux dire qu'ils n'ont eu que cinq enfants, l'un d'eux est mort.\nमेरा मतलब है के उनके सिर्फ, जैसे कि, उनके पाँच बच्चे थे, उनमें से एक मर गया।\nУ них было пять детей, один и из которых умер.\nNamaanisha kwamba walikuwa na wana watano, mmoja wao akafa.\nฉันหมายความว่าพวกเขามี ประมาณว่า พวกเขามีลูกห้าคน หนึ่งในนั้นเสียชีวิต\nDemek istediğim, sadece beş çocuğu vardı, biri öldü.\nمیرا مطلب یہ ہے کہ وہ صرف جیسے ہی تھے، ان کے پانچ بچے تھے، ان میں سے ایک مر گیا.\nÝ tôi là họ chỉ có 5 đứa con nhưng một đứa đã mất rồi.\n我的意思是，他们只有五个孩子，其中的一个已经死了。\nوبالطبع لم يُجب أندروف جروميكوف عن أي شيء ولكننا توصلنا إلى المعلومات من التصوير الذي التقطته الوحدة الثانية.\nИ, разбира се, Андрей Громико не отговори нищо, но ние имахме цялата информация от филмите, които U2 беше направил.\nUnd Natürlich Androv Gromikov hat nichts beantwortet, aber wir hatten alle Informationen von den Filmen die U2 gemacht hatte.\nΚαι, φυσικά, ο Androv Gromikov δεν απάντησε τίποτα, αλλά είχαμε όλες τις πληροφορίες από τα φιλμ που είχε πάρει το U2.\nAnd, of course, Androv Gromikov didn't answer anything, but we had all the information from the films the U2 had taken.\nY, por supuesto, Androv Gromikov no respondió a nada, pero disponíamos de toda la información de las películas hechas por el U2.\nEt, bien sûr, Androv Gromikov n'a rien répondu, mais nous avions toute l'information des films que le U2 avait pris.\nऔर, ज़ाहिर है, एंड्रोव ग्रोमिकोव ने कोई जवाब नहीं दिया, लेकिन हमारे पास यू 2 की फिल्मों के आधार पर सारी जानकारी थी।\nКонечно, Андрей Громыко не дал никаких комментариев, но из фильмов мы знаем, что U2 были захвачены.\nNa, kwa kweli, Androv Gromikov hakujibu kitu chochote, lakini tulikuwa na habari zote kutoka kwa filamu za U2 zilizochukuliwa.\nและแน่นอนว่า Androv Gromikov ไม่ได้ตอบอะไร แต่เรามีข้อมูลทั้งหมดจากภาพยนตร์ว่า U2 ที่ใช้ไปแล้ว\nVe tabi ki, Androv Gromikov hiçbir şeye cevap vermedi, ama U2'nin aldığı filmlerden tüm bilgilere sahiptik.\nاور ظاہر ہے ، اندروف گرومیکوف نے کوۂی جواب نہیں دی ، لیکن ہمارے پاس یو ٹو کے فلم کہ وجع سے ساری معلومات تھی\nVà, đương nhiên, Androv Gromikov không trả lời điều gì, nhưng chúng tôi có tất cả các thông tin từ những thước phim do U2 cung cấp.\n当然，安德罗夫格罗米科夫没有回答任何问题，但我们收到了U2拍摄的电影中所有的信息。\nقالت كانت هناك دموع تنهمر من عينيها . 
وبعد ذلك قالت أن جو قَدِم إلى الشرفة .\nТя каза, че от очите ѝ излизат само сълзи и тогава каза, че Джо се е появил на верандата.\nSie sagte, dass nur Tränen aus ihren Augen kamen und sie sagte, dann sagte sie, Joe kam auf die Veranda.\nΕίπε ότι μόνο δάκρυα έπεφταν από τα μάτια της και είπε, τότε είπε στον Joe να έρθει στη βεράντα.\nShe said there were just tears coming down from her eyes and she said, then she said Joe came up on the porch.\nDijo que le caían lágrimas de los ojos, y después dijo que Joe fue hasta el porche.\nElle a dit qu'il y avait juste des larmes qui descendaient de ses yeux et elle a dit, puis elle a dit que Joe est venu sur le porche.\nवह कह रही थी की सिर्फ आंसू आ रहे थे उसके नयन से  और उसने बताया , फिर उसने बताया जो पोर्च पर आ गया\nОна говорит, что просто у нее из глаз текли слезы, и она говорит, и тут, говорит она, на веранду вышел Джо.\nAlisema kwamba machozi ylikuwa yanamtiririka kutoka usoni. Alafu akasema pia Joe alishuka kutoka ukumbini.\nเธอกล่าวว่ามีน้ำตาไหลออกมาจากตาของเธอ และเธอกล่าวว่าโจมาปรากฏตัวที่ชานบ้าน\nGözlerinde sadece gözyaşları olduğunu söyledi ve Joe'nun verandaya geldiğini söyledi.\nاس نے کہا کے اس کی آنکھیں سے آنسو بہ رہی تھیں ، پھر جو چھتاً پر آگیا\nCô ấy nói chỉ có mình cô ấy là rơi nước mắt và cô ấy nói là, sau đó cô ấy nói rằng Joe xuất hiện trên hiên nhà.\n她说只有眼泪从她眼中流下，她说，然后她说乔在门廊出現。\nحتى لو كانت الطائرة محترقة ، فلماذا تحترق ، وستحترق من خلال عنصر الرصاص حتى يتسرب الإشعاع.\nДори ако самолетът се запали, защо ще изгори и ще се стопи през водещ компонент, за да изтече радиацията.\nSelbst wenn das Flugzeug brannte, warum würde es, es brennen, uh, und es würde durch eine Bleikomponente schmelzen, so dass die Strahlung austreten würde.\nΑκόμα κι αν το αεροσκάφος έπαιρνε φωτιά, θα καιγόταν και θα έλιωνε μέσω ενός βασικού συστατικού για να μην διαρρεύσει η ακτινοβολία.\nEven if the aircraft was on fire, why it, it would uh, burn and it would melt through a lead component for the radiation to leak out.\nIncluso si la aeronave estaba en llamas, ¿por qué se quemaría y se derretiría si tiene un componente de plomo para que la radiación se filtre?\nMême si l'avion était en feu, pourquoi est-ce que, est-ce qu'il euh, brûlerait et fondrait à travers une composante de plomb pour que le rayonnement s'échappe.\nयद्यपि हवाईजहाज जली भी रहेगी, वह जलकर अपने अत्यावश्यक उपकरणों को पिघलकर रेडिएशन को निकलने क्यों देगी ।\nДаже при пожаре в самолете, вряд ли бы он прогорел до того, чтобы расплавить свинцовую оболочку и допустить утечку радиации.\nHata kama ndege ile ilikuwa inateketea, kwa nini, ingekuwa uh, teketea and ingeweza kuyeyuka kwa sehemu ya risasi ili mionzi kuvuja.\nแม้ว่าเครื่องบินกำลังถูกไฟไหม้, ทำไมมัน, มันถึง อึก, ไหม้ เเละ มันอาจจะหลอมละลายส่วนประกอบตะกั่วซึ่งจะทำให้รังสีที่จะรั่วออกมา\nBir uçak alev alsa bile, ki neden yansın, radyasyonun sızması için kurşundan yapılan kısımların erimesi gerekir.\nاگر جہاز کو آگ لگ بھی جاتی ، جل کے ، لیڈ کے جزو سے بہار پگلنا پرتا ، ، اور اس کے بعد ہی تابکاری باہر نکلتا ۔\nNgay cả khi chiếc máy bay đang cháy, thì tại sao, ừ sao nó lại bị cháy và có thể tan chảy xuyên qua một thành phần bằng chì để bức xạ bị rò rỉ.\n即使飞机着火了，它也会燃烧并通过使铅制组件熔化来让辐射泄露出去。\nلكن كان عملي أن أضع المظلات عليها ومنقذي الحياة عند تحميلها يبدأوون التحرك إلى موقع خارجي .\nНо работата ми беше да поставя парашути и животоспасителни уреди, когато щяхме да го заредим и да тръгнем за дестинация в чужбина.\nAber meine Aufgabe war es, Fallschirme und Rettungswesten darauf zu legen, während wir laden und an einen Ort in Übersee gehen würden.\nΑλλά η 
δουλειά μου ήταν να βάλω τα αλεξίπτωτα και τα σωσίβια, όπου θα τα φορτώναμε και θα ξεκινούσαμε σε μια τοποθεσία στο εξωτερικό.\nBut my job was to put parachutes on it and life preservers when we would load it and start to an overseas location.\nPero mi trabajo era colocar paracaídas y chalecos salvavidas en el momento de cargar y despegar hacia un lugar extranjero.\nMais mon travail était de mettre des parachutes sur ça et des gilets de sauvetages lorsque nous le chargions et quelque part à l'étranger commencions .\nलेकिन मेरा काम इस पर पैराशूट रखने और जीवन के संरक्षण का था जब हम इसे लोड करेंगे और एक विदेशी जगह पर शुरू करेंगे।\nМоя же работа была в том, чтобы укладывать парашюты и спасательные жилеты, когда мы грузились и отправлялись за океан.\nLakini kazi yangu ilikuwa kuweka parashuti juu yake na vyombo vya kuokoa maisha wakati tulipakia mzigo na kuanza kuenda mahali nje ya nchi.\nแต่งานของฉันคือการใส่ร่มชูชีพลงบนมันและอุปกรณ์ที่ใช้สวมเพื่อให้ลอยอยู่เหนือน้ำเมื่อพวกเราต้องขนมันและเดินทางไปยังสถานที่ต่างประเทศ\nAma benim işim onu yüklediğimizde ve denizaşırı bir konuma gitmek üzere yola çıktığımızda onun üzerine paraşütleri ve can yeleklerini koymaktı.\nلیکن میرا کام اس پر پیراشوٹ ڈالنا تھا اور زندگی بچانے والی اشیاء ۔ جب ہم اسے لوڈ کریں گے اور بیرون ملک مقیم جگہ کی طرف سفر شروع کریں گے ۔\nNhưng công việc của tôi là đặt dù trên đó và bảo vệ cuộc sống khi chúng tôi tải nó và bắt đầu đi tới một địa điểm ở nước ngoài.\n但我的工作是在它和救生工具上放降落伞，然后装载它，开始到海外的一个地点。\nهذه هي الطريقة التي، آآ، أحافظ بها على ربط حزامي.\nИ ето така оставам заета.\nSo, äh, ich bleibe dran.\nΤώρα κάπως, παραμένω μπλοκαρισμένος.\nNow that's how, uh, I stay buckled in.\nAsí es como, mmm..., me quedo abrochado.\nC'est comme ça que, maintenant, je reste accroché.\nअब यह कैसे, उह, मैं अंदर झुक रहा हूँ\nНу вот таким образом, эм, я остался пристёгнут.\nSasa hivyo ndivyo jinsi, uh, mimi hufungiwa ndani kwa bakoli.\nตอนนี้นั่นไงล่ะ อือ ฉันใส่หัวเข็มขัดค้างไว้\nŞimdi işte böyle eğri kaldım.\nتو اس طرح میں ، ام، , بیلٹ کے اندر رہتا ہوں۔\nBây giờ đó là cách, uh, tôi ở lại.\n就是我能坚持住的原因。\nهذا هو رئيس الرقيب كليم فرانسيس ، متقاعد من القوات الجوية الأمريكية.\nТова е главен сержант Клем Франсис, пенсиониран от военновъздушните сили на САЩ.\nDies ist Chief Master Sergeant Clem Francis, im Ruhestand von der US Air Force.\nΑυτός είναι ο Αρχιλοχίας Clem Francis, συνταξιούχος από την Πολεμική Αεροπορία των ΗΠΑ.\nThis is Chief Master Sergeant Clem Francis, retired from the US Air Force.\nEste es el sargento jefe maestro Clem Francis, jubilado de las fuerzas aéreas estadounidenses.\nVoici le sergent-major chef Clem Francis, retraité de l'US Air Force.\nयह अमेरिकी वायु सेना से सेवानिवृत्त मुख्य मास्टर सार्जेंट क्लेम फ्रांसिस है।\nЭто Главный Мастер сержант Клэм Фрэнсис, ушедший в отставку из военно-воздушных сил США.\nMkuu wa Jeshi Clem Francis astaafu kutoka kitengo cha majeshi ya hewani ya Marekani.\nนี่คือพลอากาศเอกเคลม ฟรานซิส เกษียณมาจากกองทัพอากาศอเมริกา\nBu, ABD Hava Kuvvetleri'nden emekli olan Astsubay Kıdemli Başçavuş Clem Francis'tir.\nیہ صردار ماسٹر  افسر کلیم فرانسس ، امریکی ایر فورس سے رٹاۂر کر چکيں ہیں۔\nĐây là Trung sĩ trưởng Clem Francis, đã nghỉ hưu từ Không quân Hoa Kỳ.\n这是士官长Clem Francis，原空军少校，他已经从美国空军退役。\nحسنا حدث أن  طائرتين  أو ثلاثة ينبغي أن تصل خلال اسبوع ولم أعرف أين وجهتهما .\nАми, стигна се до момент, когато имаше два или три пристигащи самолета на седмица, и аз не знаех накъде летят.\nNun es kam soweit, dass zwei oder drei Flugzeuge pro Woche ankamen und ich nicht wusste, wohin sie flogen.\nΛοιπόν πήγαμε εκεί όπου 
φτάνουν δύο ή τρία αεροσκάφη την εβδομάδα και δεν ήξερα με ποιο πετάνε.\nWell it got to where there's two or three aircraft arrive in a week and I didn't know where they're flying to.\nBueno, llegó hasta el punto en que hay dos o tres aeronaves que llegan en una semana y no sabía a dónde volaban.\nEh bien, c'est arrivé là où deux ou trois appareils arrivent en une semaine et je ne savais pas où ils volaient.\nफिर ऐसे होने लगा की हफ्ते में दो से तीन हवाईजहाज आने लगी और कहा जा रही है इसका पता ही नहीं था मुझे |\nВ общем, дошло до того, что было по одному-два самолёту в неделю и никто не знал, куда они летят.\nBasi ilifikia mahali kuna ndege mbili ama tatu zinazofika kila wiki na sikujua kwenye zinapaa zikienda.\nคือมันไปถึงจุดที่เครื่องบินสองสามลำบินเข้ามาต่อสัปดาห์ และฉันไม่รู้ว่าพวกมันปิดไปไหนกัน\nBir haftada iki ya da üç uçağın vardığı yere ulaştı ve nereye uçtuklarını bilmiyordum.\nنوبت یہاں تک پہنچ گئی ہے کہ ایک ہفتے میں دو یا تین ہوائی جہاز آتے ہیں اور مجھے نہیں پتہ کہ وہ اڑ کر کہاں جا رہے ہیں.\nVâng, có lúc có hai hoặc ba máy bay đến trong một tuần và tôi không biết họ đang bay đi đâu.\n那么它就到了一两周内有两架或三架飞机到达的地方，而我不知道他们飞往哪里。\nلقد تلقوا تدريبهم من قبل بملابس الضغط الكاملة واخذت مني وقتا إذا ارتديت بدلات الضغط الكاملة .\nТе вече бяха преминали обучението си във височинните костюми, а отнема известно време, ако облечете височинните костюми.\nSie hatten ihr Training bereits in den vollen Druckanzügen absolviert und ich brauchte eine Weile, den vollen Druckanzug anzuziehen.\nΕίχαν ήδη την εκπαίδευσή τους στις στολές πλήρους πίεσης και μου πήρε λίγο χρόνο να βάλω την στολή πλήρους πίεσης.\nThey had already had their training in the full pressure suits and it taken me a while if you go into full pressure suits.\nEllos ya habían tenido su entrenamiento en los trajes presurizados y me había costado algo de tiempo ponerme un traje presurizado.\nIls avaient déjà eu leur entraînement pour les combinaisons pressurisées et ça m'a pris du temps si tu mets une combinaison pressurisée complète.\nउन्होंने पहले ही अपने प्रशिक्षण दबाव वाले सुटो में कर लिया था और मुझे कुछ वक़्त हो चूका था दबाव वाले सुटो को पहनकर ।\nОни уже тренировались в высотно-компенсирующих костюмах, и мне потребовалось некоторое время, чтобы надеть такой костюм.\nNa tayari walikuwa na mazoezi yao katika mkazo mkuu wa suti na imenichukua muda kama utaenda kwa mkazo mzima wa suti.\nพวกเขาได้รับการฝึกอบรมในชุดสูทความดันอากาศและมันทำให้ฉันใช้เวลาสักพักนึงถ้าคุณต้องใส่ชุดสูทความดันอากาศนั้น\nTam baskı takımlarında eğitimlerini çoktan hazırlamışlardı ve tam baskı takımlarına giderseniz benim bir süremi aldı.\nانہوں نے پہلے سے ہی مکمل دباوں کے لباس میں ٹرینينگ کر لی تھی ، مجھے تھوری دیر لگی۔\nHọ đã luyện tập trong bộ com lê đầy áp lực và mất một thời gian nếu bạn mặc bộ quần áo đầy áp lực.\n他们已经穿上全压力服装进行训练，如果你穿全压力服，我会花上一段时间帮你。\nوكان زير نساء ، و نعم ، وكان كأنه في الخارج هناك . و، آه ، هكذا ، أن تعرف ، لم أحبه ، لكن على أية حال تلك هي حكاياتي.\nА той беше любовчия, да, такъв беше. И, така, знаете ли, аз не го харесвам, но така или иначе това са моите истории.\nUnd er war ein Schürzenjäger, und ja, er war dort draußen. Und naja, ich mochte ihn nicht, aber was soll es, dies sind meine Geschichten.\nΚαι ήταν ένας φιλάνθρωπος, ναι ναι, έτσι ήταν εκεί έξω. Και, αχ, έτσι, ξέρεις, δεν μου άρεσε, αλλά τέλος πάντων είναι οι ιστορίες μου.\nAnd he was a philanderer, an d oh yeah, he was like out there . And, ah, so, you know, I didn't like him, but anyway those are my stories.\nY él era un mujeriego, y oh sí, estaba como ahí fuera. 
Creo que hice una pregunta, o él entró, o yo qué sé.\neh bien j'arrive ici ce matin et um, j'oublie comment, je pense que soi j'ai posé une question et il vint ici ou, peut importe.\nवैसे मैं आज सुबह वहां जाता हूं, पता नहीं कैसे भूल जाता हूं, पर शायद मैंने उसे सवाल पूछा और वह उधर आया, जो भी हो।\nНу, этим утром иду я туда и, э-э-э, не помню как, наверное, или я задал вопрос и он вошел, или, ну, в общем, ладно.\nVyema nafika pale leo asubuhi na,nasahau ni vipi, nafikiria labda niliuliza swali na akaja humu ndani ama, vyovyote.\nเอ่อ ฉันอยู่ที่นั่นเมื่อเช้านี้และ เอิ่ม ฉันลืมไปว่าอย่างไร ฉันคิดว่าฉันถามคำถามและเขาก็เข้ามาที่นั่นหรือ อะไรก็แล้วแต่\nBu sabah oraya geliyorum ve nasıl olduğunu unuttum, sanırım ya ben bir soru sordum ve o da oraya geldi ya da her neyse.\nبہتر ہے‏، میں وہاں آج صبح جاؤں گا‏، میں بھول گیا کیسے‏، میرا خیال ہے کہ یا تو میں نے ایک سوال پوچھا اور وہ وہاں آیا‏ یا‏، جو کچھ بھی ہو۔\nỒ tôi tới đây sáng nay và ừm, tôi quên mất bằng cách nào, tôi nghĩ chắc tôi đã hỏi một câu hỏi gì đó và anh ta tới đây hoặc là, điều gì đó tương tự.\n我今天早上到那里, 呃，我忘了是我问了一个问题还是他进来了, 随便吧。\nلم نعلم إلى أين كانوا ذاهبين.\nНе знаехме къде отиват.\nWir wussten nicht wo sie hin gegangen sind.\nΔεν ξέραμε πού πήγαιναν.\nWe didn't know where they were going.\nNo sabíamos a donde iban.\nNous ne savions pas où ils allaient.\nहमें नहीं पता था कि वे कहां जा रहे थे।\nМы не знали куда они направляются.\nHatukujua ni wapi walienda.\nเราไม่รู้ว่าพวกเขากำลังจะไปที่ไหน\nNereye gittiklerini bilmiyorduk.\nہم نہیں جانتے تھے کہ وہ کہاں جا رہے تھے۔\nChúng tôi không biết họ sẽ đi đâu.\n我们不知道他们去了哪里。\nلكنني كنت.., إنسى هذا الأمر ، سأقوم لتناول وجبة الغداء فقد كنت جائعاً.\nНо аз си казах, забравих, ще ям обяд, бях гладен.\nAber ich sagte, vergiss es, ich werde Mittag essen, ich war hungrig.\nΑλλά ήμουν σαν, ξέχνα το, θα φάω μεσημεριανό, πεινούσα.\nBut I was like, forget it, I'm going to eat lunch I was hungry.\nPero yo estaba como, olvídalo, me voy a comer el almuerzo. Tenía hambre.\nMais j'étais comme, oublie ça, je vais déjeuner, j'avais faim.\nलेकिन मैं जैसे कि यह भूल गया था, कि मैं दोपहर का खाना खाने जा रहा था लेकिन मैं भूखा था।\nНо я такой, забудь уже, я хочу пообедать я голоден.\nSikufiria tena. 
Nilikuwa na njaa na nilitaka kula.\nแต่ฉันประมาณว่า, ลืมมันไป, ฉันกำลังจะไปทานข้าวเที่ยง ฉันหิวข้าว\nAma ben, unut gitsin, öğle yemeği yiyeceğim, açım havasındaydım.\nپھر میں نے کہا چوڑو ، میں کھانا کھانے جا رہیں ہوں ، بھوکھ لگ رہی ہے\nNhưng tôi đã như thế, quên nó đi, tôi sẽ ăn trưa tôi đã đói.\n但我想，忘记它，我要去吃午饭，我饿了。\nالجزء كان ، كان هناك 158 جزءًا منه ، وكان علينا أن نكسره وجعله يتراجع معًا ولا يخطئ أبداً.\nЧастта беше, имаше 158 части и трябваше да го разбием, да го съберем заедно, да го счупим и никога да не правим грешка.\nDie Sachel war, es gab 158 Teile dazu und wir mussten es auseinanderbrechen, es wieder zusammenbauen, es kaputt machen und niemals einen Fehler machen.\nΤο σημαντικό ήταν, υπήρχαν 158 τμήματα και έπρεπε να τα σπάσουμε, να τα ξαναβάλουμε, να τα σπάσουμε μαζί και ποτέ να μην κάνουμε λάθος.\nThe part was, there was 158 parts to it and we had to break it down put it back together break it down and never make a mistake.\nLa parte era, había 158 partes y tuvimos que dividirla, volver a unirla, dividirla y nunca cometer un error.\nla chose était, il y avaient 158 morceaux à cela et\nबात यह थी की उसमे १५८ भाग थे और हमे उसे छोटे टुकड़ो में तोड़कर उसे वापस जोड़ना था फिर से तोडना था और इस दौरान कभी गलती की कोई गुंजाइश नहीं थी ।\nТам было 158 деталей, и все их надо было разобрать, затем собрать обратно, затем снова разобрать, и ни разу не ошибиться.\nSehemu ilikuwa, kulikuwa na sehemu mia moja hamsini na nane yake na ilibidi tuivunje na tuiweke pamoja, kuvunja na kutofanya makosa.\nชิ้นส่วนนั้น, มีอยู่ 158 ชิ้น และพวกเราต้องทำให้มันเเตกออกจากกัน รวมมันเข้าด้วยกัน ทำให้มันเเตก และจะต้องไม่ทำพลาด\nBu kısım, 158 kısmı vardı ve onu bölüp tekrar bir araya getirmemiz bölmemiz ve hiç hata yapmamamız gerekiyordu.\nاہم حصہ یہ تھا کہ، اس کے 158 حصے تھے اور ہمیں اسے توڑنا تھا دوبارہ جوڑنا پھر توڑنا تھا بغیر کسی غلطی کے.\nVấn đề là, nó có 158 mảnh và chúng ta phải gỡ chúng ra ghép chung lại gỡ chúng ra và không bao giờ phạm lỗi.\n部分原因是，有158个零件，我们必须打破它，然后再拼回来，打破它，永远不会犯错误。\nحسناً ، ودعني أقل لك أنني وصلت للمرحلة أني كنت على وشك الانسحاب.\nТака че, позволете ми да ви кажа, че днес стигнах до мястото, на което щях да се откажа.\nAlso, lass mich dir sagen, heute bin ich an dem Punkt angelangt, der mich beinahe zum hinschmeissen gebracht hätte.\nΈτσι και άσε με να σου πω ότι έφτασα στο σημείο σήμερα που ήμουν έτοιμος να σταματήσω.\nSo and let me tell you I got to the point today where I was about to quit.\nAsí que, déjame decirte que hoy llegué al punto en el que estuve a punto de renunciar.\nEt donc laissez-moi vous dire je suis arrivé aujourd'hui au point où j'étais sur le point de démissionner\nतो मैं आपको बता देता हूँ कैसे मैं छोड़ने के किनारे तक पहुँचा आज ।\nДа, и еще должен тебе сказать, что сегодня был момент, когда я совсем было собрался уволиться.\nKwa hiyo na niruhusu nikwambie nimefikia hatua leo ambapo nilikuwa karibu kuacha.\nเอาล่ะให้ฉันบอกคุณ ฉันถึงจุดในวันนี้ที่ฉันกำลังจะเลิก\nÖyleyse söyleyeyim, bugün bırakmak üzere olduğum noktaya geldim.\nتو میں تمہیں بتاتا ہوں آج میں اس نقطہ پر پہنچ گیا جہاں میں چھوڑنے والا تھا.\nThế đấy và hãy để tôi nói với bạn rằng tôi đã đến mức mà hôm nay tôi sắp bỏ thuốc lá.\n所以，让我告诉你，今天我正式提出辞职。\nلا أريد الذهاب إلى الـ SS الثالثة فهي سرب الدعم الاستراتيجي الثالث.\nНе искам да влизам в Третата СЕ, която е Третата Стратегическа помощна Ескадрила.\nIch möchte mich nicht näher mit dem Dritten SS befassen, was Dritter Strategischer Support-Schwadron bedeutet.\nΔεν θέλω να μπω στη Τρίτη SS που είναι η Τρίτη Μοίρα Στρατηγικής Υποστήριξης.\nI don't want to go into 
the Third SS that's Third Strategic Support Squadron.\nNo quiero entrar en la Tercera SS que es el Tercer Escuadrón de Apoyo Estratégico.\nJe ne veux pas entrer dans la troisième SS qui est le troisième escadron de soutien stratégique.\nमैं थर्ड एसएस यानी की तीसरी रणनीतिक सहायता स्क्वाड्रन में नहीं जाना चाहता।\nНе хочу говорить на тему Третьего СС. Это Третий эскадрон стратегического сопровождения.\nSitaki kuenda katika SS ya tatu, hiyo ni kundi la manowari la mkakati wa msaada la tatu.\nฉันไม่ต้องการเข้าไปในกองกำลังที่สามซึ่งเป็นกองกำลังสนับสนุนกองทหารที่สาม\nÜçüncü SS'e, Üçüncü Stratejik Destek Bölüğüne girmek istemiyorum.\nمیں تیسرے ایس ایس میں نہیں جانا چاہتا یہ تیسرا اسٹریٹجک سپورٹ سکواڈران ہے.\nTôi không muốn vào SS thứ ba là Đội hỗ trợ chiến lược thứ ba.\n我不想进入第三战略支援中队。\nلم يكن أي شيء لكن صحراء كان هناك شجر الميرمية بالخارج على الطريق.\nТова не беше нищо друго освен пустиня; на пистата растеше пелин.\nEs gab nichts außer einer Wüste; dort war ein Salbeistrauch draußen auf dem Rollfeld.\nΔεν ήταν τίποτα παρά μια έρημος, υπήρχαν σαγιονάρες στον διάδρομο.\nIt wasn't nothing but a desert; there was sagebrush out on the runway.\nNo era más que un desierto; había artemisa en la pista.\nC'était juste un désert ; la sauge poussait sur la piste d'atterrissage.\nयह कुछ नही बस एक रेगिस्तान था; वहाँ बाहर मार्ग पर सेजब्रश था।\nЭто была не пустота, а пустыня. На взлетной полосе росли кусты полыни.\nHakukuwa kitu bali janga na kuoshwa barabara yote.\nมันไม่ใช่อะไรเลยแต่มันคือทะเลทราย มีพุ่มของต้นเสชบรัชอยู่ข้างนอกรันเวย์\nBir çölden başka bir şey değildi; pistte çalı vardı.\nاور کچھ نہیں ، ریگستان تھا ، رن وے پر سرف جھاڑی نظر آں رہے تھے ۔\nChẳng có gì ngoài sa mạc; có cây đan sâm trên đường đi.\n那里不过是一片沙漠，跑道上有灌木。\nولد في عام 1880, أو شئ مثل 188, أظن أنها كانت 1889 ، أعتقد أنها كانت كذلك عندما ولد.\nТой е роден през 1880-те, мисля 188, или 1889, мисля, че тогава е роден.\nEr wurde in den 1880igern geboren, wie 188, ich glaube 1889, Ich glaube da wurde er geboren.\nΓεννήθηκε κάπου στο 1880, κάπως 188, νομίζω ότι ήταν το 1889, νομίζω ότι ήταν όταν γεννήθηκε.\nHe was born in 1880 something, like 188, I think it was 1889, I think it was when he was born.\nNació en mil ochocientos ochenta y algo, pudo haber sido en 1888 o 1889.\nIl est né en 1880, quelque chose comme 188..., je crois que c'était en 1889, je crois que c'était juste après sa naissance.\nउनका जन्म १८८० में हुआ था, जैसे १८८, मुझे लगता है कि यह १८९८ था, मुझे लगता है कि उसी समय वह पैदा हुआ था।\nОн родился в 1880 году, что-то вроде 188. Я думаю, что это было в 1889 году. 
Я думаю, что это был год, когда он родился.\nAlizaliwa kipindi cha mwaka wa 1988-1989 hivi nikidhani.\nเขาเกิดในปี 1880 บางอย่าง,เช่น 188 ผมคิดว่าปี 1889, ผมคิดว่ามันคือปีที่เขาเกิดมา\n1880'li bir tarihte doğdu, 188 gibi, sanırım 1889'du, galiba doğduğu zaman buydu.\nوہ 1880 میں کچھ 188 کی طرح پیدا ہوا تھا،مجھے لگتا ہے کہ یہ 1889 تھا، میرا خیال ہے کہ وہ پیدا ہوا تھا.\nAnh ấy sinh năm 1880, như 188, tôi nghĩ đó là năm 1889, tôi nghĩ đó là khi anh ấy được sinh ra.\n他出生在1880年，好像是188，我或者是1889年，我想他应该是那时候出生的。\nمن الأفضل أن تلف المسمار الصغير للأسفل قليلاً لأنك يمكن أن تدمر رئة الشخص بسهولة شديدة.\nПо-добре да завиеш малкото винтче малко, защото може лесно да увредиш белите дробове на всеки.\nDu drehst die kleine Schraube besser etwas nach unten weil du könntest sonst beschädigen und anderen ihre lunge sehr leicht\nΘα ήταν προτιμότερο να αναστρέψετε λίγο το μικρό κοχλία γιατί θα μπορούσατε να βλάψετε και τους πνεύμονες του ατόμου πολύ εύκολα.\nYou'd better get the little screw turned down a little bit because you could damage and individual's lungs very easily.\nSerá mejor que baje un poco el pequeño tornillo porque podrías lastimar los pulmones de una persona fácilmente.\nTu ferais bien de pencher un peu la vis parce que tu pourrais endommager les poumons individuels très facilement.\nआप को चाहिए कि वह छोटा पेंच थोडा सा नीचे करें क्यों कि आप व्यक्ति के फेफडों को एकदम जल्दी नुकसान पहुँचा सकते है।\nВам лучше немного повернуть маленький винт, потому что вы можете без труда повредить легкие человека.\nIngekuwa bora ungegeuza chini skrubu ndogo kidogo kwa sababu unaweza kuharibu mapafu ya mtu kwa urahisi sana.\nคุณควรจะขันสกรูปิดลงอีกนิดเพราะว่าคุณอาจจะทำลายปอดแต่ละอันอย่างง่ายดายมาก\nKüçük vidayı biraz gevşetirsen iyi olur, çünkü bireyin yaralanmasına neden olur ve akciğerlerine çok kolay bir şekilde zarar verebilirsin.\nآپ کے لئے چھوٹے پیچ کو تھوڑا نیچے کرنا بہتر رہے گا کیونکہ آپ بہت آسانی سے کسی کے پھیپھڑے خراب کر سکتے ہیں.\nTốt hơn là bạn nên giảm bớt một chút vì bạn có thể gây tổn thương phổi của mọi người rất dễ dàng.\n你最好把这个小螺丝拧低一点，不然很容易伤到某个倒霉蛋的肺部。\nتلقيت أوامر للذهاب إلى ديل ريو ، تكساس ، لذلك عندما وصلت هناك ، حسنا ، اكتشفت أنني يجب أن أذهب إلى قاعدة سلاح الجو في لافلين.\nПолучих заповеди да отида в Дел Рио, Тексас, така че когато пристигнах там, разбрах, че трябва да отида до базата на военновъздушните сили в Лафлин.\nIch bekam den Befehl, nach Del Rio, Texas, zu fahren. 
Als ich dort ankam, wusste ich, dass ich zur Laughlin Air Force Base musste.\nΈλαβα εντολές να πάω στο Del Rio, TX, οπότε όταν έφτασα εκεί, ανακάλυψα ότι έπρεπε να πάω στη βάση της Πολεμικής Αεροπορίας Laughlin.\nI got orders to go to Del Rio, TX, so when I arrived out there, well, I found out I had to go to Laughlin Air Force Base.\nRecibí órdenes para irme a Del Rio, TX, así que cuando llegué allí, bueno, me enteré de que tenía que irme a la base de Laughlin Air Force.\nJ'ai reçu l'ordre d'aller à Del Rio au Texas, donc quand je suis arrivé là-bas, eh bien j'ai appris que je devais aller à la base aérienne de Laughlin.\nमुझे डेल रियो, टेक्सास जाने का आदेश मिला, इसलिए जब मैंने वहां पहुंचा, मुझे पता चला कि मुझे लाफ्लीन एयर फोर्स बेस में जाना था।\nЯ получил приказ отправиться в Дель-Рио, Техас, поэтому, когда я приехал туда, ну, я узнал, что я должен отправиться на базу ВВС Лафлина.\nNilipata maagizo ya kwenda Del Rio, TX, hivyo wakati nilipofika huko, niligundua nilipaswa kwenda kwa Laughlin Air Force Base.\nฉันได้รับคำสั่งที่จะไปที่ Del Rio, TX ดังนั้น เมื่อฉันไปถึงที่นั่นแล้ว, อืม, ฉันพบว่าฉันต้องไปที่ฐานทัพอากาศ Laughlin\nDel Rio, Texas'a gitme emri aldım, haliyle, oraya vardığımda Laughlin Hava Kuvvetleri Üssü'ne gitmek zorunda olduğumu anladım.\nمجھے ڈیل ڑیو جانے کا حکم دیا گیا تھا ، جب وہاں پہنچا  تو پتا چلا کے لافلن ایئر فورس بیس جانا پرے گا۔\nTôi có những đơn hàng đến Del Rio, TX, nên khi tôi đến đó, well, tôi nhận ra rằng tôi phải đến Doanh trại Không quân Laughlin.\n我接到了去德克萨斯州德尔里奥的命令，所以当我到达那里的时候，我发现我必须去劳克林空军基地。\nسوف يحصل كل واحد على شامبانيا وبعض الناس لا يشربونها وقد شرب الأطفال ما تبقى لهذا فقد شربنا جميعاً كل الشامبانيا.\nВсеки получава шампанско и някои хора не го пият, така че това, което остава, го пият децата, така че ние обикаляхме и пиехме всичкото това шампанско.\nJeder bekommt Champagner und einige Leute trinken es nicht, also was bleibt, trinken die Kinder, also gingen wir herum und tranken all diesen Champagner.\nΌλοι παίρνουν σαμπάνια και μερικοί άνθρωποι δεν την πίνουν, έτσι αυτό που έμεινε στα παιδιά να πιουν γι 'αυτό πηγαίναμε γύρω γύρω πίνοντας όλη αυτή τη σαμπάνια.\nEverybody gets champagne and some people don't drink it so what's left the kids drink so we were going around drinking all this champagne.\nTodo el mundo se sirve champán, pero no todos se terminan lo que se sirven; lo que queda en las copas se lo beben los niños, así que allí estábamos nosotros, dando vueltas y bebiendo todo ese champán.\nTout le monde reçoit du champagne et certaines personnes n'en boivent pas. Alors ce qui reste, les enfants le boivent. Donc nous nous baladions pour boire tout ce champagne.\nहर कोई शैम्पेन पीता  है और कुछ इसे नही पीते हैं तो जो बच जाती है उसे बच्चे पीते हैं, तो हम घूम-घूमकर इस बची हुई शैम्पेन को पी रहे थे।\nВсе получают шампанское и некоторые его не пьют, поэтому то, что остается, выпивают дети, поэтому мы ходим, пьем все это шампанское.\nKila mtu huwa yuwapewa mvinyo lakini wengine hawanywi na kwa hivyo kinachobakia watoto hunywa. 
Huwa tunaenda kila mahali tukinywa mvinyo.\nทุกคนได้แชมเปญและบางคนก็ไม่ดื่มมัน สิ่งที่เหลือเด็ก ๆ ก็ดื่ม ดังนั้นเราจึงเดินไปทั่วเพื่อดื่มแชมเปญทั้งหมดนี้\nHerkes şampanya alır ve bazı insanlar içmez, böylece çocuklar ne içerse o kadar şampanya içeriz.\nسب لوگ شیمپین ہو جاتے ہیں اور کچھ لوگ اس کو پیتے نہیں کرتے ہیں لہذا بچوں کو پینے سے بچا جاتا ہے لہذا ہم اس شیمپین کو پینے کے ارد گرد جا رہے تھے.\nMọi người đều uống rượu sâm banh và một số người không uống nó vì vậy những gì còn lại cho trẻ em uống vì vậy chúng tôi đã đi xung quanh uống tất cả rượu sâm banh này.\n每个人都拿到了香槟，有些人不喝，所以孩子们都喝剩了，所以我们就到处走喝了所有剩下的香槟。\nهناك العديد من القصص في المدينة العارية.\nИма толкова много истории в голия град.\nEs gibt viele Geschichten um die nackte Stadt\nΥπάρχουν πολλές ιστορίες στην γυμνή πόλη.\nThere are many stories in the naked city.\nHay muchas historias en la ciudad desnuda.\nIl y a de nombreuses histoires dans la cité sans voiles.\nइस खुले शहर में कईँ कहानियाँ हैं।\nВ голом городе много историй.\nKuna hadithi nyingi fiche katika jiji.\nมีเรื่องราวมากมายในเมืองที่ว่างเปล่า\nÇıplak şehirde pek çok hikaye var.\nبرہنہ شہر میں کئی داستانیں ہیں.\nCó quá nhiều câu chuyện trong thành phố bỏ hoang này.\n在这个赤裸城市里有许多故事。\nأنت تعيش وتتعلم كما تعلم، عندما تختبر الطائرة.\nТи живееш и се учиш чрез тестване на самолета.\nWeißt du, du lebst und lernst wenn du Flugzeuge testest.\nΖεις και μαθαίνεις, ξέρεις, όταν δοκιμάζεις, αεροσκάφη.\nYou live and learn, you know, when you test, uh, aircraft.\nSe vive y se aprende, ya sabes, cuando pruebas, eh, avión.\nTu vis et tu apprends, tu sais, lorsque tu testes, oh, un avion.\nआप जागते हो  और सीखते हैं, आप जानते हैं, जब आप परीक्षण करते हैं, उह, विमान का |\nНу, знаете ли, когда вы испытываете воздушные судна, всегда есть что-то доселе неизвестное.\nUnaishi na kujifunza, unajua, Unapojaribu ndege.\nคุณมีชีวิตและเรียนรู้,คุณรู้ไหมว่า,เมื่อคุณทดสอบ,อ๊ะ,เครื่องบิน\nBilirsin, uçakları test ettiğinde yaşıyor ve öğreniyorsun.\nآپ زندہ رہتے ہیں اور سیکھتے ہیں جب آپ ہوائی جہاز کو ٹیسٹ کرتے ہیں.\nBạn sống và học tập, bạn biết đấy, khi bạn kiểm tra, ôi, máy bay.\n你知道，当你在测试，呃，飞机时，你在生活和学习。\nكان هذا هو الهدف\nИскам да кажа, че в това беше цялата работа.\nIch meine, das war der ganze Punkt.\nΑυτό είναι το όλο θέμα.\nI mean that was the whole point.\nQuiero decir que ese fue el punto completo.\nJe veux dire que c'était tout le problème.\nमेरा मतलब है कि यह पूरी बात थी\nЯ имею в виду, что это и был весь смысл.\nNamaanisha hiyo ndio ilikuwa fikra nzima.\nฉันหมายความว่านั่นคือประเด็นทั้งหมด\nYani bütün nokta buydu.\nمیرا مطلب یہ تھا کہ پوری بات.\nÝ tôi là đó là toàn bộ vấn đề.\n我的意思是，这就是重点。\nكان من قاعدة جوية حلقت فوق كوبا ، وبالطبع تم إسقاط رودولف أندرسون.\nТой беше от военновъздушна база и прелетя над Куба, и разбира се, Рудолф Андерсън беше свален.\nEs war von einem Luftwaffenstützpunkt, der über Kuba geflogen ist, und natürlich wurde Rudolph Anderson abgeschossen.\nΉταν από μια αεροπορική βάση που πέταξε πάνω από την Κούβα, και φυσικά ο Rudolph Anderson καταρρίφθηκε.\nIt was from a airbase that flew over Cuba, and of course Rudolph Anderson was shot down.\nFue desde una base aérea que pasaba sobre Cuba, y por supuesto derribaron a Rudolph Anderson.\nC'était à partir d'une base aérienne qu'il a survolé Cuba, et bien sûr Rudolph Anderson a été abattu.\nवह क्यूबा के ऊपर से जाने वाली हवाई अड्डे से थी , और बेशक रुडोल्फ एंडरसन मारे जा चुके थे ।\nЭто было с авиабазы, она полетела над Кубойи и, конечно же, Рудольф Андерсон был сбит.\nIlikuwa kutoka uwanja wa ndege uliopitia Cuba na ambapo Rudolph 
Anderson alipigiwa risasi.\nมันมาจากฐานทัพอากาศที่บินไปคิวบา และแน่นอนว่า Rudolph Anderson ถูกยิงตกลงมา\nKüba'nın üzerinden geçen, bir hava üssünden geliyordu ve Rudolph Anderson tabii ki vuruldu.\nyeh aik airbase se tha jo cuba kai upar se guzra, aur yaqeenan Ruduloph Anderson ko mar gira diya gaya.\nAnh ta bay đến Cuba từ sân bay, và dĩ nhiên Rudolph Anderson bị bắn hạ.\n它来自一个飞越古巴的空军基地，当然鲁道夫·安德森被击落了。\nلم أقل ذلك مرة أخرى لهذا فقد ضغط عليا وأحرجني لهذا لا اريد أن أعرف متى نحتاجها.\nТой не каза отново време, така че той просто ме закара там, а аз дори не знам кога е необходимо.\nEr sagte keine Zeit mehr, also hat er mich dort sorgend gelassen und ich weiß nicht mal wenn es gebraucht wird.\nΔεν έλεγε ξανά , γι 'αυτό με πήγε εκεί πέρα κι αγχωνόμουνα, και πραγματικά δεν ξέρω καν πότε χρειάζεται\nHe didn't say a time again, so he just got me over there stressing out, and I really don't even know when it's needed.\nNo volvió a decir una hora, así que me hizo ir allí estresado, y ni siquiera sé cuándo es necesario.\nEncore une fois, il n'a pas donné d'heure, donc il m'a fait encore plus stresser, et je ne sais vraiment pas quand est-ce qu'il le faut.\nउसने फिर एक बार नहीं दोहराया , इसी लिए उसने मुझे वहा तनाव में दाल दिया, और मुझे ये भी नहीं पता की कब इसकि ज़रूरत पड़ेगी ।\nОн не повторил время снова, он просто доставил меня туда и я на сомом деле даже не знаю, когда это требуется.\nYeye hakusema wakati tena, kwa hiyo alinipata hapo yu nikisisitiza , na sijui hata wakati itakapohitajika.\nเขาไม่ได้พูดเวลาอีกครั้ง, ดังนั้นเขาเพียงแค่ให้ฉันไปที่นั่นยืดเส้นยืดสาย, และฉันก็ไม่รู้ด้วยซ้ำว่าเมื่อไหร่ที่จะต้องการมัน\nTekrar bir zaman bile demedi, o yüzden beni oraya zorladı, ve ne zaman gerekli olduğunu bile bilmiyorum.\nانہوں نے پھر کوۂی وقت نہیں بتتایا ، اور میں یہاں پریشان بیٹھی ہوں کیونکہ میں نہیں جانتی کے ان کو کب ضرورت ہوگی\nAnh ấy không nói lại lần nữa, vì thế anh ấy bắt tôi ở đó vật vờ, và tôi thậm chí không biết khi nào thì cần thiết.\n他没有再说一遍，他只是留我在那里，我的压力很大，我甚至不知道什么时候要。\nربما كان هذا هو أول شيء أتذكره من كوني صبي صغير ، آه ، خاصة حول شيء قمت به بشكل خاطئ.\nТова вероятно беше първото нещо, което си спомням, когато бях малко дете, особено ако съм направил нещо нередно.\nDas ist wahrscheinlich das erste woran ich mich aus meiner frühen Kindheit erinnere, äh, vor allem wenn ich etwas falsch gemacht hatte.\nΉταν ίσως το πρώτο πράγμα που θυμάμαι από τότε που ήμουν μικρό παιδί, αχ, ειδικά για κάτι που είχα κάνει λάθος\nIt was probably the first thing I remember from being a little kid about, ah, especially about something that I'd done wrong.\nProbablemente es la primera cosa que recuerdo de cuando era un niño pequeño, especialment4e cuando había hecho algo malo.\nC'était probablement la première chose dont je me souvenais de ma petite enfance, et en particulier au sujet d'une bêtise.\nयह शायद पहली बात है जो मेरे याद मे आता है मैं एक छोटा बच्चा होने के बारे मे, खासतौर पर कुछ गलत करने के बारे मे।\nВероятно, это была первая вещь, о которой я помню, будучи маленьким ребенком,.. 
ах, особенно о чем-то, что я сделал неправильно.\nHuenda ni jambo la kwanza nakumbuka kutoka kuwa mtoto mdogo, hasa juu ya kitu ambacho nilifanya makosa.\nมันอาจเป็นสิ่งแรกที่ฉันจำได้นับจากเมื่อตอนเป็นเด็กเล็ก ๆ อา โดยเฉพาะอย่างยิ่งเกี่ยวกับบางสิ่งที่ฉันอาจจะทำผิดไป\nMuhtemelen küçük bir çocuk olmakla, ah, özellikle de yanlış yaptığım bir şeyle ilgili hatırladığım ilk şey buydu.\nیہ میرے بچپن کی شاۂد پہلی یاد ہے ، خاص طور پر کیونکہ میری غلطی تھی\nĐó có lẽ là điều đầu tiên tôi nhớ là một đứa trẻ, ah, đặc biệt là về điều gì đó mà tôi đã làm sai.\n这可能是我从小记得的第一件事，啊，尤其是我做错的那些事。\nإنهم فقط ليسو كما كانت ترغب  أن يكونو سود فى هذه الايام ، ولهذا كان، فأنت تعرف ، أنا أعتقد ، بأنه ، بأنه كان من الممكن ، أنت تعرف ، فى بداية ١٩٣٠ ، أوه عندما فعلو ذلك .\nНа тях просто не им харесваше какво е да си чернокож в онези дни и това беше, мисля, в началото на 30-те години на 20-ти век, когато го направиха.\nSie haben einfach nicht gemocht, wie es damals war, schwarz zu sein, und das war, weißt du, ich vermute, das war wahrscheinlich, weißt du, in den frühen 1930er Jahren, als sie das taten.\nΑπλά δεν τους άρεσε το πως ήταν να είσαι μαύρος εκείνη την εποχή, και αυτό ήταν, ξέρετε, υποθέτω, αυτό ήταν μάλλον, ξέρετε, στις αρχές της δεκαετίας του 1930, όταν το έκαναν αυτό.\nThey just didn't like what it was like being black in those days, and that was, you know, I guess, that, that was probably, you know, in the early 1930s, uh, when they did that.\nSimplemente no les gustó lo que era ser negro en esos días, y eso fue, ya sabes, supongo que, eso fue probablemente, ya sabes, a principios de la década de 1930, cuando lo hicieron.\nIls n'aimaient simplement pas ce que c'était que d'être noir à l'époque, et c'était, vous savez, je suppose, que c'était probablement, vous savez, au début des années 1930, quand ils l'ont fait.\nउन दिनों में वे काले रंग के होने की तरह नहीं थे, और वह था, आप जानते हैं, मुझे लगता है, कि, संभवतः, आप जानते थे, 1 9 30 के शुरुआती दिनों में, उह, जब वे ऐसा करते थे।\nИм просто не нравилось то, что означало в тех дни быть черным, и ето было, ер помоймо, ето было, вероятно, ну в в начале 1930-х годов, ер, когда они ето совершили\nWao hawakupenda tu kama ilivyokuwa nyeusi siku hizo, na hilo lilikuwa, wajua, nadhani, kwamba, hilo lilikuwa labda, wajua, mwanzoni wa miaka ya 1930, uh, walivyofanya hivyo.\nพวกเขาแค่ไม่ชอบอะไรที่ค่อนข้างขุ่นมัวในวันโน้น และนั่นแหละ คุณรู้ ฉันเดานะ นั่นอาจจะเป็นไปได้ คุณรู้นะ ช่วงต้น ๆ ปี ค.ศ. 
1930 อ่าา เมื่อพวกเขาทำสิ่งนั้น\nSadece o günlerde siyah olmanın nasıl bir şey olduğunu sevmediler, ve bilirsin, bilirsin, muhtemelen 1930'ların başlarında bunu yaptıla\nان دنوں جس طرح سیاہ فام لوگوں کی زندگی ہوا کرتی تھی انہیں بس وہ پسند نہیں تھی، اور وہ، آپ کو معلوم ہی ہے میرا خیال ہے، یہ شاید، آپ کو پتہ ہوگا، 1930 کے اوائل کی بات ہے، اہ، جب انہوں نے یہ کیا تھا۔\nHọ chỉ không thích những gì tương tự như màu da đen vào thời ấy, và như bạn biết đấy, tôi đoán, đó có lẽ, như bạn biết đấy, vào đầu những năm 1930, ừ, khi họ đã làm điều đó.\n在那个时代，他们不喜欢黑人的所作所为，而且，你知道，我想，也许，你知道，在20世纪30年代早期，呃，他们是如何做的。\nهو من اليونان وهو من قرية صغيرة في اليونان تعرف باسم توكاليكا وقد أتى إلى أمريكا وأعتقد أن ذلك كان عام 1969 أو 1970 وتزوج بعد ذلك بفترة قصيرة.\nТой е от Гърция, от едно малко селце в Гърция, наречено Токалека, и е дошъл в Америка и мисля, че е било 1969 или 1970 г., и скоро след това се оженил.\nEr kommt aus Griechenland und er kommt aus einem kleinen Dorf in Griechenland namens Tokaleka und er kam nach Amerika und ich glaube es war 1969 oder 1970 und er heiratete kurz darauf.\nΕίναι από την Ελλάδα και είναι από ένα μικρό χωριό στην Ελλάδα που ονομάζεται Τοκαλέκα και ήρθε στην Αμερική. Πιστεύω ότι ήταν το 1969 ή το 1970 και σύντομα παντρεύτηκε.\nHe is from Greece and he is from a small village in Greece called Tokalleka and he came to America and I believe it was 1969 or 1970 and he shortly got married.\nÉl es de Grecia y él es de un pequeño pueblo en Grecia llamado Tokalleka, vino a América, creo que fue en 1969 o 1970  y se casó pronto.\nIl vient de Grèce, d'un petit village en Grèce appelé Tokalleka, et il est venu en Amérique je crois que c'était en 1969 ou 1970, et il s'est marié rapidement.\nवह ग्रीस से है और ग्रीस के टोकलेका नामक छोटे से गांव से है और वह अम्रिका आया और मेरा यह मानना है की सं १९६९ या १९७० की बात है और जल्द ही उसकी ब्याह हो गई थी |\nОн из Греции, из маленькой деревни в Греции под названием Токаллека, и он приехал в Америку, по-моему, в 1969 или 1970 году и вскоре женился.\nYeye anatoka Ugiriki kutoka katika kijiji kidogo cha Ugiriki kinachoitwa Tokalleka na alikuja Amerika na ninaamini ilikuwa mwaka wa 1969 au 1970 na hakukawia kabla ya kuoa.\nเขามาจากประเทศกรีกและเขามาจากหมู่บ้านเล็ก ๆ ในกรีกชื่อว่า Tokalleka และเขามาที่อเมริกาและฉันเชื่อว่ามันคือปี ค.ศ. 
1969 หรือ 1970 และในไม่นานเขาก็แต่งงาน\nO Yunanlı ve Yunanistan'da Tokalleka adında küçük bir köyden Amerika'ya gelmiş ve bunun 1969 veya 1970 olduğuna inanıyorum ve kısa bir süre sonra evlenmiş.\nاس کا تعلق یونان سے ہے اور وہ یونان کے ایک چھوٹے سے گاؤں سے تعلق رکھتا ہے جس کا نام ٹوکل لیکا ہے اور وہ امریکا آیا تھا اور میرا خیال ہے یہ 1969 یا 1970 تھا اور اس کے تھوڑی دیر بعد اس نے شادی کر لی.\nAnh ấy đến từ Hy Lạp và từ một ngôi làng nhỏ tên là Tokalleka và anh ấy đến Mỹ và tôi tin rằng đó là năm 1969 hoặc 1970 và chỉ sau đó thời gian ngắn anh ấy đã lập gia đình.\n他来自希腊，他来自希腊的一个叫Tokalleka的小村庄，我相信他是在1969或1970年来美国的，并且他很快就结婚了。\nعلى أيه حال ، عاودت الاتصال ب رومانا لأسئلها عن ، كنت مثل ، حسنا ، دعنى اسرع معها ، وكان لدى سؤال عن شئ ما كنت أفعله .\nТака или иначе, се обаждам на Рамона, защото имах въпрос, и бързах да го задам, и имах въпрос за нещо, което правех.\nJedenfalls, ich habe Ramona zurückgerufen, weil ich eine Frage hatte, ich dachte mir: Okay, lass mich damit weitermachen und ich hatte eine Frage über etwas, das ich gemacht habe.\nΈτσι τέλος πάντων, καλώ τη Ramona επειδή είχα μια ερώτηση, ήμουν σαν, εντάξει, επιτρέψτε μου να βιαστώ με αυτό, και είχα μια ερώτηση για κάτι που έκανα.\nSo anyhow, I call Ramona back because I had a question about, I was like, All right, let me hurry up on with it, and I had a question about something I was doing.\nAsí que devolví la llamada a Ramona porque tenía una pregunta sobre… vale… déjame acabar… tenía una pregunta sobre una cosa que estaba haciendo.\nDonc de toute façon, je rappelle Ramona parce que j'avais une question à lui poser et je disais d'accord, finissons-en vite avec ça, et j'avais une question à propos de quelque chose que j'étais en train de faire.\nफिर, जो  भी हो , मैंने रमोना को कॉल किया क्युकी मेरे कुछ सवाल थे , और मै था की , ठीक है, जल्दी जल्दी कर देता हु , मुझे कुछ और सवाल करने थे उसके बारे में जो मै कर रहा था अभी ।\nВ любом случае я позвал Рамону обратно, потому что у меня был вопрос к ней о том, какой я. Хорошо, позволь мне поторопиться с этим; и еще был вопрос о том. что я делаю.\nHata hivyo, nampigia tena Ramonakwa sababu nilikuwa na swali. 
Alafu nikamweleza kwamba acha niharakishe kwa sababu nilikuwa na swali kuhusu jambo fulani ambalo nilikuwa nafanya.\nดังนั้นไม่ว่าจะอย่างไร ฉันเรียกราโมน่ากลับมาเพราะว่าฉันมีคำถาม ฉันแบบว่า เอาล่ะ ให้ฉันรีบเร่งมือกับมันหน่อย และฉันมีคำถามเกี่ยวกับสิ่งที่ฉันทำอยู่\nHer neyse, Ramona'yı geri aradım çünkü soracağım bir şey vardı, ben şey, pekala, toparlayayım dedim, yaptığım şey hakkında bir sorum vardı.\nتو میں نے ریمونہ کو فون کیا کیونکہ مجھے کچھ پوچھنا تھا، میں نے کہا چلو جلدی سے پوچھتی ہوں ، مجھے اپنے کام کے بارے میں ایک سوال ہے۔\nVì vậy, dù sao đi nữa, tôi gọi Ramona trở lại vì tôi có một câu hỏi về, tôi có cảm giác, Được rồi, hãy để tôi thực hiện nhanh và tôi có một câu hỏi về điều gì đó mà tôi đang làm.\n所以无论如何，我给雷蒙娜回电话，因为我有一个问题，就像是，好吧，让我赶快搞定它，而且我对我在做的事情有疑问。\nلا أحد يعرف إلى أين ذهبوا.\nНикой не знаеше къде отидоха.\nNiemand wusste, wo sie hingegangen waren.\nΚανένας δεν ήξερε που πήγαν\nNobody knew where they went.\nNadie sabía a dónde iban.\nPersonne ne savait où ils étaient.\nकोई नहीं जानता था कि वे कहाँ गए|\nНикто не знал, куда они ушли.\nHakuna aliyejua walipokwenda.\nไม่มีใครรู้ว่าพวกเขาไปที่ไหน\nNereye gittiklerini kimse bilmiyordu.\nکوئی بھی نہیں جانتا تھا کہ وہ کہاں گئے تھے.\nKhông ai biết chúng tôi đã đi đâu.\n没有人知道他们去哪里了。\nولم يتمكنوا من البقاء في منطقة أوحوستا بسبب علم الناس بأنهم حاولوا فعل شيء محرم جداً، وحاولوا أن يبدوا من البيض.\nИ те не можеха да останат в района на Огъста, защото хората знаеха, че са се опитали да направят нещо, което наистина беше табу, и са искали да минат за бели.\nUnd sie konnten nicht in der Augusta Region bleiben, weil die Menschen dort wussten, dass sie versucht hatten etwas sehr tabuisiertes zu tun, und versucht hatten als Weiße durchzugehen.\nΚαι δεν μπορούσαν να μείνουν στην περιοχή της Augusta επειδή οι άνθρωποι ήξεραν ότι είχαν προσπαθήσει να κάνουν κάτι που ήταν πραγματικά ταμπού και προσπάθησαν να περάσουν για λευκοί.\nAnd they couldn't stay in the Augusta area because people knew that they had tried to do something that was really taboo and try to pass for white.\nY no podían quedarse en el área de Augusta porque la gente sabía que habían intentado hacer algo que era realmente tabú y tratar de pasar por blanco.\nEt ils ne pouvaient pas rester dans la région d'Alberta parce que les gens savaient qu'ils avaient vraiment enfreint un tabou et qu'ils avaient essayé de passer pour Blancs.\nऔर वे अगस्ता क्षेत्र में नहीं रह सकते क्योंकि लोग जानते थे कि उन्होंने कुछ ऐसा करने की कोशिश की थी जो वास्तव में निषिद्ध था और वाइट के लिए पार करने का प्रयास किया था।\nИ им нельзя было оставаться в районе города Огаста, так как местные уже знали об их попытке нарушить табу и выдать себя за белых.\nNa hawakuweza kukaa eneo la Augusta kwa sababu watu walijua kuwa walijaribu kufanya kitu ambacho kilikuwa mwiko na kujaribu kupitisha nyeupe.\nและพวกเขาก็ไม่สามารถอยู่ในเขต Augusta ได้เพราะว่าผู้คนรู้ว่าพวหเขาได้พยายามทำบางสิ่งที่เป็นข้อห้ามและพยายามส่งต่อให้กับคนขาว\nVe onlar Augusta bölgesinde kalamazlardı çünkü insanlar gerçekten tabu olan bir şey yapmayı denediklerini ve beyaz için geçmeye çalıştıklarını biliyorlardı.\nاور وہ آگسٹا کے علاقے میں نہیں رہ سکتے تھے کیونکہ لوگ جانتے تھے کہ انہوں نے ایسا کچھ کرنے کی کوشش کی تھی جو بہت ہی برا سمجھا جاتا تھا اور سفید فام بننے کی کوشش کی تھی۔\nVà họ không thể ở lại khu vực Augusta bởi vì mọi người biết rằng họ đã cố gắng làm một điều cấm kỵ và cố gắng biến địa vị của mình thành người da trắng.\n他们不能呆在奥古斯塔地区，因为人们知道他们曾经尝试过做一些被禁的事情，并试图假装白人。\nنعم ، حسنا ، الشاب هنا\nДа, момчето е тук.\nJa, nun, er ist hier.\nΝαι, ο τύπος 
είναι εδώ.\nYeah, well, the guy's here.\nSí, bueno, el chico está aquí.\nOuais, eh bien, le mec est là.\nहाँ, ठीक है, आदमी यहाँ है।\nНу что ж, он здесь.\nNdio, vyema, mtu huyo amekuja.\nนั่นแหละ เจ้าตัวมาแล้ว\nEvet, adam burada.\nجی ہاں، ٹھیک ہے، لڑکا یہاں ہے.\nYeah, vâng, gã ta đang ở đây.\n耶，那个人在这。\nكنا نشاهد شيئًا على التلفزيون.\nГледахме нещо по телевизията.\nWir haben etwas im Fernsehen angeschaut.\nΒλέπαμε κάτι στην τηλεόραση.\nWe were watching something on TV.\nEstábamos viendo algo en la tele.\nNous regardions quelque chose à la télé.\nहम टीवी पर कुछ देख रहे थे |\nМы смотрели что-то по телевизору.\nTulikuwa tukitazama kitu kwenye Televisheni.\nพวกเรากำลังรับชมบางอย่างในทีวี\nTelevizyon'da bir şeyler izliyoruz.\nہم ٹی وی پر کچھ دیکھ رہے تھے۔\nChúng tôi đang xem gì đó trên TV.\n我们正在看电视。\nعلى أي حال ، أعتقد أنني تحدثت مع رامونا مرة أخرى.\nТака или иначе, мисля, че отново говорих с Рамона.\nJedenfalls denke ich, dass ich wieder mit Ramona gesprochen habe.\nΈτσι κάπως, νομίζω ότι μίλησα ξανά με τη Ραμόνα.\nSo anyhow, I think I spoke to Ramona again.\nEntonces, pienso que hablaré con Ramona de nuevo.\nDe toute façon, je pense que j'ai à nouveau parlé à Ramona.\nतो किसी भी तरह, मुझे लगता है कि मैं रमोना के लिए फिर से बात की थी।\nТак что, в любом случае, мне кажется, я снова говорил с Рамоной.\nKwa njia moja au nyingine niliongea na Ramona tena.\nอย่างไรก็แล้วแต่, ฉันคิดว่าฉันคุยกับราโมนาอีกครั้ง\nHer neyse, sanırım Ramona'yla bir kez daha konuştum.\nتو کسی بھی طرح،میں سوچتا ہوں کہ میں نے دوبارہ رامونا سے بات کی.\nDù sao đi nữa, tôi nghĩ tôi đã nói chuyện với Ramona một lần nữa.\n不管怎样，我想我又和雷蒙娜说话了。\nأشياء صغيرة كهذه أحدثت فرقًا كبيرًا في ما كنت أحاول القيام به.\nМалки неща като това направиха голяма разлика в това, което се опитвах да направя.\nSolche Kleinen dinge machten einen grossen Unterschied zu dem was ich versuchte zu tun.\nΤα μικρά πράγματα όπως αυτό έκαναν την μεγάλη διαφορά σε αυτό που προσπαθούσα να κάνω.\nLittle things like that made a big difference in what I was trying to do.\nPequeñas cosas como esta marcaron una gran diferencia en lo que estaba tratando de hacer.\nDes petites choses comme celles-là font une différence énorme dans ce que j'essaye de faire.\nऐसे ही छोटी सी बातें भने मेरा कर्म पर एक बड़ा अंतर बना दिया ।\nТакие, казалось бы, незначительные моменты оказывали сильное влияние на то, чем я пытался заниматься.\nMambo madogo yaliyoleta utofauti mkubwa kwa mambo niliyokuwa nikijaribu kuyafanya.\nสิ่ง ๆ เช่นนั้นสร้างความแตกต่างอย่างมากเกี่ยวกับสิ่งที่ฉันพยายามจะทำ\nBunun gibi küçük şeyler yapmaya çalıştığım şeyde büyük fark yarattı.\nاس طرح کی چھوٹی چیزیں جو میں کرنے کی کوشش کر رہا تھا اس میں ایک بڑا فرق بن گیا.\nNhững điều nhỏ nhặt như thế đã tạo nên sự khác biệt lớn trong những gì tôi đã cố gắng làm.\n“蝴蝶效应”正是我在尝试做的事情。\nحسناً هل هناك أحد هناك لمساعدتي.\nЕ, там няма никой, за да ми помогне.\nNun, da ist niemand da um mir zu helfen.\nΛοιπόν, δεν υπάρχει κανείς εκεί για να με βοηθήσει.\nWell, there's nobody there to help me.\nBueno, no hay nadie ahí para ayudarme.\nEh bien, il n'y a personne pour m'aider.\nखैर, मेरी मदद करने के लिए कोई नहीं है।\nВедь там мне помочь некому.\nNaam, hakuna mtu yeyote hapo atakayenisaidia.\nช่างดีจริงๆ, ไม่มีใครที่ที่จะช่วยฉัน\nAslında, orada bana yardım edecek hiç kimse yok.\nميری مدد کرنے کے لۂے کوۂی نہیں ہے۔\nVâng, không có ai ở đó để giúp tôi.\n没有人能帮我。\nلقد كان هو الضحية الوحيدة للأزمة الكوبية, وأه, قيصر أه, حصل على الصور وطار مباشرة إلى سلاح الجو في أندروز في واشنطن.\nТой беше единствената жертва през кубинската криза, и направи 
снимките и излетя директно за военновъздушната база Андрюс във Вашингтон.\nEr war das einzige Opfer der kubanischen Krise, und ähm, Kaiser ähm, er machte die Fotos und flug sofort zu der Andrews Air Force in Washington.\nΉταν η μοναδική απώλεια στην Κρίση της Κούβας και, ο Kaiser, πήρε τις φωτογραφίες και πέταξε κατευθείαν στο Αεροδρόμιο Andrews της Πολεμικής Αεροπορίας στην Ουάσινγκτον.\nHe was the only casualty of the Cuban Crisis, and uh, Kaiser uh, he got the pictures and flew straight to Andrews Air Force in Washington.\nFue la única baja de la Crisis de Cuba y, eh, Kaiser, eh, se hizo con las imágenes y voló directo a Andrews Air Force en Washington.\nIl était la seule victime de la crise cubaine, et euh, Kaiser euh, il a pris les photos et a volé directement jusqu'à Andrews Air Force à Washington.\nवह क्यूबाई संकट का एकमात्र हताहत था,संकट, और उह, kaiser,की  वह तस्वीरें लेकर वॉशिंगटन में सीधे एंड्रयूज वायु सेना के लिए उड़ान भरी।\nОн был единственной причиной Кубинского Кризиса и эм, Кайзер, эм, он получил фотографии и полетел прямо на авиабазу Эндрюс в Вашингтоне.\nAlikuwa mahuttuti katika mchafuko wa Cuban na Kaiser alizipata picha na akaelekea Andrews Air Force mjini Washington.\nเขาเป็นคนเดียวที่เสียชีวิตจากวิกฤติคิวบา และเอ่อ Kaiser เอ่อ เขามีรูปภาพและบินตรงไปที่ Andrews Air Force ในวอชิงตัน\nKüba Krizi'nin tek zayiatıydı ve Kaiser, resimleri aldı ve doğrudan Washington'daki Andrews Hava Kuvvetleri'ne uçtu.\nوہ کیوبا کے بحران کا واحد زخمی تھا، اور uh، کیسر uh، وہ تصاویر ملی اور واشنگٹن میں اینڈریو ایئر فورس کی طرف براہ راست اڑ گئے.\nAnh ta là nạn nhân duy nhất của cuộc khủng hoảng Cuba, và uh, anh ta đã chụp ảnh và bay thẳng đến Không quân Andrews ở Washington.\n他是古巴危机的唯一的受害者，呃，Kaiser，他得到了照片，直接飞往华盛顿的安德鲁斯空军基地。\nأمم, قالت, يا صغيري ، أنت لا تفهم الحياة بنفس طريقتي في فهم الحياة.\nИ тя каза, скъпа, ти не разбираш живота така, както аз го разбирам.\nÄhm, und sie sagte, sagte sie, sie sagte: Baby, sagte sie, du verstehst das Leben nicht so, wie ich das Leben verstehe.\nΜμμ, και είπε, είπε, είπε, μωρό μου, είπε, δεν καταλαβαίνεις τη ζωή με τον τρόπο που καταλαβαίνω εγώ τη ζωή.\nUm, and she said, she said, she said, Baby, she said, You don't understand about life the way I understand about life.\nHum y ella dijo... dijo... dijo... cariño, dijo, no entiendes la vida del modo que yo entiendo la vida.\nUm, et elle a dit, elle a dit, elle a dit, bébé, elle a dit, Vous ne comprenez pas  la vie de la façon dont je la comprends.\nउम, उसने कहा, उसने कहा, उसने कहा, बेबी, उसने कहा, जीवन के बारे में आप उस तरह नहीं समझते जैसा मैं समझती हूँ।\nХмм... и тут она говорит, она говорит, говорит: Милый, говорит она, - Ты не знаешь жизни - не знаешь её так, как я её знаю.\nAlisema, Mpenzi, huyaelewi maisha jinsi ninavyoyaelewa.\nอืม และเธอพูดว่า ที่รัก คุณไม่เข้าใจชีวิตเหมือนที่ฉันเข้าใจหรอก\nAh ve dedi ki Bebeğim hayatı benim anladığım şekilde anlamıyorsun.\nام، اور اس نے کہا، اس نے کہا، اس نے کہا، بچہ، اس نے کہا،آپ زندگی کے بارے میں سمجھتے ہیں جس طرح زندگی کے بارے میں نہیں سمجھتے.\nỪm, và cô ấy nói, cô ấy nói, cô ấy nói, Anh yêu, cô ấy nói, Anh không hiểu về cuộc sống theo cách em hiểu về cuộc sống.\n嗯，她说，她说，她说，宝贝，她说，你理解生命的方法跟我理解生命的方法不一样。\nوبعد ذلك حصلت عليه، وأنا مثل العظماء، ماذا عساي أن أفعل به؟\nТогава го разбирам и се чувствам страхотно, какво да правя с него?\nDann bekomme ich es und ich denke toll, was mache ich damit?\nΜετά το κατάφερα και νιώθω ωραία, Τι να κάνω με αυτό;\nThen I get it and I'm like great, What do I do with it?\nLuego lo pillo y me quedo como, genial. 
¿Qué hago con ello?\nPuis je l'ai et je me dis : « bon super, qu'est-ce que je fais avec ? »\nफिर मुझे मिल जाती है और मै हु की, बढ़िया है , इसके साथ क्या करू मै अब ?।\nПотом я понимаю, что к чему, и такая: ну зашибись, и что мне с этим делать?\nKisha nikiipata na ninaona ni vyema, nitafanya nini na hilo?\nถ้างั้นฉันเข้าใจล่ะ และฉันก็ประมาณว่า เยี่ยมเลย ฉันจะต้องทำอะไรกับมัน?\nSonra onu alıyorum ve bir harika oluyorum, onunla ne yapacağım?\nپھر مجھے ملتا ہے اور میں سوچھتا ہوں کے اس کے ساتھ کیا کروں گا\nSau đó, tôi đã hiểu ra và tôi cảm thấy tuyệt, tôi phải làm gì với nó?\n然后我明白了，我的感觉很棒，我该怎么办呢？\nقامت وكالة الاستخبارات الأمريكية بتفريغ الفيلم، وأخذه إلى الأمم المتحدة في اليوم التالي.\nЦРУ изпразни филма и ги отведе в ООН на следващия ден.\nDie CIA entlud den, den Film, brachte sie am nächsten Tag zu den Vereinten Nationen.\nΗ CIA κυκλοφόρησε το φιλμ και το πήγε στα Ηνωμένα Έθνη την επόμενη μέρα.\nThe CIA unloaded the, the film, taken them to the United Nations the next day.\nLa CIA descargó la película y la llevó a las Naciones Unidas al día siguiente.\nLa CIA téléchargea le, le film, les emmena aux nations Unis le jour suivant.\nसी आई ए ने चित्र उतार दिए , अगले ही दिन उसे यूनाइटेड नेशंस ले गए ।\nЦРУ выдало плёнку, и на следующий день она была доставлена в ООН.\nCIA iliufunguanisha ile sinema na kuipeleka ka Muungana wa Kimataifa(UN) siku iliyofuata.\nCIA ขนภาพยนตร์ออกโดยการนำพวกมันไปยัง United Nations ในวันถัดไป\nCIA, filmi indirdi ve filmi ertesi gün Birleşmiş Milletlere götürdü.\nسی آئی اے نے اس فلم کو بے نقاب کیا، اگلے دن ان کو اقوام متحدہ میں لے گئے.\nCIA đã dỡ bỏ tấm phim, đưa chúng đến Liên Hiệp Quốc vào ngày hôm sau.\n中央情报局卸载了这部电影，第二天把他们带到了联合国。\nلقد أخبرته بالفعل، وحاولت أن أشرح له أني كنت محبطاً ولم يكن لدي جميع المعلومات التي أحتاجها.\nВече му казах, опитах се да му обясня, че съм разочарована, че не разполагах с цялата информация, от която имах нужда.\nIch habe ihm schon gesagt, ich habe versucht ihm zu erklären, dass ich enttäuscht war, weil ich nicht die ganze Information hatte die ich brauchte.\nΤου το είπα ήδη, προσπάθησα να του εξηγήσω ότι ήμουν απογοητευμένος που δεν είχα όλες τις πληροφορίες που χρειαζόμουν.\nI already told him, I tried to explain to him that I was frustrated I didn't have all the information I needed.\nYa le conté, intenté explicarle que estaba frustrado porque no tenía toda la información que necesitaba.\nJe lui ai déjà dit, j'ai essayé de lui expliquer que j'étais frustré de ne pas avoir toutes les informations dont j'avais besoin.\nमैंने उसे पहले ही कह दिया था, मैंने उसे समजने की कोशिश की की मै पक चूका था और मेरे पास आवश्यक सारे जानकारी भी नहीं थी ।\nЯ уже сказал ему, я попытался объяснить ему, что я был расстроен, потому что у меня не было всей необходимой информации.\nNishamwambia, Nilijaribu kumweleza kuwa nimekatishwa tamaa kwani sikuwa na habari yote niliyohitaji.\nฉันบอกเขาแล้ว ฉันพยายามอธิบายกับเขาแล้วว่าฉันกลัวฉันไม่มีข้อมูลที่ฉันต้องการ\nOna zaten söyledim, ihtiyacım olan tüm bilgiyi almadığım için hayal kırıklığı yaşadığımı ona açıklamaya çalıştım.\nمیں نے اس کو بتا دیا تھا، سمجھانے کی کوشش کی تھی کہ ساری معلومات نہ ہونے کی وجع سے مایوس تھی۔\nTôi đã nói với anh ấy, tôi đã cố gắng giải thích cho anh ấy rằng tôi đã thất vọng vì tôi không có đủ thông tin tôi cần.\n我已经告诉他了，我试图向他解释，我沮丧是我没有得到我需要的消息。\nفقط أمهلني دقيقة واحدة، إذا كنت تريد خوض ذلك، حسنًا، تفضل.\nСамо ми дайте една минута, ако искате да го прекратите, ще си отида.\nGib mir nur eine Minute, wenn du es abschneiden willst, ich würde, äh, gehen.\nΑπλά δώσε μου ένα λεπτό αν θέλετε να το 
κόψεις, θα, ε, πήγαινα.\nJust give me a minute if you want to cut it off, I'd, uh, go.\nDame solo un minuro, si quieres cortarlo, yo, hum, me voy.\nDonnez-moi juste une minute, si vous voulez le couper, euh, allez-y.\nमुझे एक मिनट दे दो अगर तुम उसे काटना चाहते हो, तो मैं जाऊँगा।\nЕсли хочешь, чтобы получилось эту штуку отрезать, дай мне минуту, тогда я, а..., иди.\nNipe dakila moja tu iwapo unataka niikate, ninaweza kuenda\nแค่ขอเวลาฉันแปปนึงหากคุณต้องการตัดมันออก ฉันจะ เอ่อ ไป\nKesmek istersen bana bir dakika ver, ben giderim.\nمجھے ایک منٹ دیجئے اگر آپ راستہ بند کرنا چاہیں گے تو میں چلا جاؤں گا۔\nChờ tôi 1 phút nếu bạn muốn cắt ngắn thời gian, tôi sẽ, đi nào.\n如果你想切断它，只需给我一分钟，我会，呃，离开。\nلذا ذهبت إلى واشنطن العاصمة، ولم أذهب مباشرة إلى..أه. التي قالوا لي عنها في أوامرهم.\nТака че отидох във Вашингтон и не отидох директно, където ми бяха казали в заповедите ми.\nAlso bin ich, also bin ich nach Washington D.C. gegangen und ich bin nicht direkt nach, ähm, dorthin gegangen, wohin ich laut meiner Anweisungen gehen sollte.\nΠήγα λοιπόν στην Ουάσινγκτον και δεν πήγα κατευθείαν, όπως με είχαν διατάξει\nSo I went to, I went to Washington D.C. and I didn't go directly to, uh, that, uh, they had told me to on my orders.\nAsí que fui a Washington D.C. y no fui directamente a donde me dijeron en mis órdenes.\nDonc je suis allé, je suis allé à Washington D.C. et je ne suis pas allé directement à, euh, cet, euh, où ils m'avaient ordonné d'aller.\nतो फिर मै वाशिंगटन डी सी चला गया और मै सीधे वहा नहीं गया , उन्होंने कहा था की हमारे आदेशों के लिए इंतज़ार करे ।\nИ я отправился в Вашингтон, но не пошел сразу в это, э-э-э, туда, куда меня направили по инструкции.\nNlienda katika, nlienda katika Washingon D.C na sikuenda moja kwa moja kwa, uh, lile, uh, walikuwa wameniambia katika maagizo yangu.\nคือฉันได้ไป, ฉันไปที่เมืองวอชิงตัน ดีซี และฉันไม่ได้ไปโดยตรงนะ, อ่า, ซึ่ง, อ่า, เขาต้องบอกฉันให้ฉันไปตามคำสั่ง\nBöylece, ben Washington D.C.'ye gittim ve doğrudan gitmedim, emirlerimde öyle yapmamı söylediler.\nTo mai gaya, mai Washington D.C. gaya aur mai direct nahi gaya, aur unho ne mjhe orders diye the.\nVà tôi đã đến, tôi đã đến Washington D.C. 
và tôi đã không đi thẳng đến, uh, chỗ đó, uh, nơi họ bảo tôi đến trong các đơn hàng.\n所以我去了，我去了哥伦比亚特区的华盛顿，我没有直接去，呃，呃，他们告诉过我，要按命令去做。\nلهذا السبب لم أتخرج من الكلية، لكنني لم أقرأ أبدًا أيًا من هذه الكتب التي كان من المفترض أن أقرأها.\nЗатова не завърших колежа, но никога не съм чел никоя от книгите, които трябваше.\nDas ist warum ich die Universität nicht beendet habe, weil ich nie die Bücher gelesen habe die ich lesen sollte\nΑυτός είναι ο λόγος για τον οποίο δεν αποφοίτησα από το κολέγιο, αλλά εγώ ποτέ, ποτέ δεν διάβασα κανένα από τα βιβλία που έπρεπε να διαβάσω.\nThat's why I didn't graduate college, but I never, I never read any of the books I was supposed to.\nPor eso no me gradúe en la universidad, pero nunca, nunca leí ninguno de los libros que tenía que leer.\nC'est pourquoi je n'ai pas obtenu mon diplôme, mais je n'ai jamais, jamais lu un seul livre que j'étais supposé lire.\nयही कारण है कि मैं कॉलेज स्नातक नहीं किया है, लेकिन मैंने उन किताबों को कभी नहीं पढ़ा जिन्हें मुझे पढ़ना था\nПо этой причине я не закончил колледж, но я никогда, ни за что не читал те книги, чтения которых от меня требовали.\nNdiyo sababu mimi sikuhitimu chuo kikuu, lakini sijawahi, sijawahi kusoma yoyote ya vitabu nilivyotakiwa.\nเหตุผลนั้นแหละที่ฉันเรียนไม่จบมหาลัย แต่ฉันไม่เคยอ่านหนังสื่อที่ฉันต้องอ่านเลยสักเล่ม\nBu yüzden üniversiteyi bitiremedim, ama asla, asla okumamgereken kitapları okumam.\nاور اس وجع سے میں کالج دگری مکمل نہیں کر سکا، لیکن میں نے کبھی بھی وہ کتابیں نہیں پڑھی جو لازمی تھیں\nĐó là lý do tại sao tôi không tốt nghiệp đại học, nhưng tôi chưa bao giờ, tôi chưa bao giờ đọc bất kỳ cuốn sách nào mà tôi cho là.\n这就是我大学没有毕业的原因，但是我读过的书没有哪本是我应该去读的。\nلقد كانت من أصحاب البشرة السمراء الفاتحة\nТя беше чернокожа със светло лице.\nSie war eine hellhäutige schwarze Person.\nΉταν ένα μαύρο άτομο με ανοιχτόχρωμο δέρμα.\nShe was a light-skinned black person.\nEra una persona negra de piel clara.\nElle était une personne noire à la peau claire.\nवे एक भूरे रंग की काली व्यक्ति थी ।\nОна была светлокожей негритянкой.\nAlikuwa mtoto wa mzungu na mwafrika.\nเธอเป็นผิวดำอ่อนๆ\nAçık tenli siyahi bir kişiydi.\nوہ ایک ہلکا پتلی سیاہ شخص تھی.\nCô ta là một người da đen.\n她是个浅肤色的黑人。\nعلى أي حال ، يذهب أبي ويصنع هذا الكوب الكبير من حليب الشوكولاته لي.\nКакто и да е, татко отива и прави една хубава голяма чаша шоколадово мляко за мен.\nWie auch immer, Papa geht und macht dieses schöne große Glas Schokoladenmilch für mich.\nΌπως και να 'χει. 
Ο μπαμπάς πηγαίνει και φτιάχνει αυτό το ωραίο μεγάλο ποτήρι σοκολατούχο γάλα για μένα.\nSo anyway, Dad goes and makes this nice big glass of chocolate milk for me.\nDe todos modos, papá se va y hace un buen vaso grande de leche con chocolate para mí.\nBon, quoi qu'il en soit, Papa va me préparer un bon grand verre de lait chocolaté.\nतो वैसे भी, पिताजी चले जाते हैं और मेरे लिए यह अच्छा व बड़ा गिलास चॉकलेट दूध बनाते हैं।\nТак что в любом случае папа пойдет и нальет мне большой стакан шоколадного молока.\nKwa hiyo, Baba huenda na kuniandalia glasi nzuri sana ya maziwa ya chokoleti .\nยังไงก็ตาม พ่อไปและทำนมช็อคโกแลตแก้วใหญ่ให้ฉัน\nYani, her neyse, Babam gidip bana bu güzel büyük bardakta çikolatalı sütü alıyor.\nتو ویسے ہی، والد چلا جاتا ہے اور میرے لئے چاکلیٹ دودھ کا یہ اچھا بڑا گلاس بنا دیتا ہے.\nVì vậy dù sao đi nữa, Bố đi và làm ly sữa sô cô la tuyệt vời này cho tôi.\n总而言之，爸爸去为我制作这一大杯好喝的巧克力牛奶。\nكان على شركة دوت شراء العقار والأشياء.\nDOT трябваше да купи имота и нещата.\nDer DOT musste das Grundstück und das ganze Zeug kaufen.\nΗ DOT έπρεπε να αγοράσει το ακίνητο και τα πράγματα.\nThe DOT had to buy the property and stuff.\nEl DOT tuvo que comprar el inmueble y tal.\nLe DOT devait acheter la propriété et tout ce qui va avec.\nडीओटी को जायदात  सामान खरीदना पड़ा।\nDOT необходимо было купить имущество и товары.\nLazima Dot angenunua mali na vitu vingine.\nDOT ต้องซื้อทรัพย์สินและสิ่งต่างๆ\nUlaştırma Bakanlığının mülk ve eşya satın alması gerekiyordu.\nڈاٹٹ کو پراپرٹی اور سامان خریدنا پڑا.\nDOT đã phải mua tài sản và các thứ khác.\nDOT必须购买那物业和东西。\nأمم ، وهكذا غادروا المدينة ، وهي لم تر أختها مرة أخرى ، ولم تر شقيقتها مرة أخرى.\nИ така, те просто напуснаха града и тя повече никога не видя сестра си отново, никога не видя сестра си отново.\nÄhm, sie haben also gerade die Stadt verlassen und sie, sie hat ihre Schwester nie wieder gesehen, ihre Schwester nie wieder gesehen.\nΧμ, και έτσι έφυγαν από την πόλη, και αυτή,  αυτή ποτέ δεν είδε ξανά την αδελφή της , δεν είδε την αδελφή της ποτέ ξανά\nUm, and so they just left town, and she, she never did see her sister again, never saw her sister again.\nMmm, así que simplemente se fueron de la ciudad, y ella, nunca volvió a ver a su hermana.\nEuh, et donc ils ont juste quitté la ville, et elle, elle n'a jamais revu sa sœur, jamais revu sa sœur.\nमं , और वो ऐसे ही नगर छोड़ दिए .और वो, उसने अपनी बहिन को फिर कभी नहीं मिलने की कोशिश नहीं की, कभी मिली भी नहीं ।\nГм, и поэтому они просто уехали из города, и она... 
она действительно больше никогда не видела свою сестру, больше никогда не видела свою сестру.\nUm, na hivyo tu waliondoka mji, na yeye, hakuwahi kumwona dada yake tena, hakuwahi kumwona dada yake tena.\nเอิ่ม และดังนั้นพวกเขาแค่ออกจากเมืองและเธอ เธอก็ไม่เคยเจอน้องสาวของเธออีกเลย ไม่เคยเจอน้องสาวของเธออีกครั้ง\nHmm, ve işte şehirden ayrıldılar, ve o, kız kardeşini bir daha asla göremedi, kız kardeşini bir daha asla göremedi.\nUm, aur isi trha wo sheher chor gaye, aur isne,apni behen ko kabhi nahi dekha, kabhi nahi dekha apni behan ko dobara.\nVà họ vừa mới rời thị trấn, và cô ấy, cô ấy không bao giờ gặp lại em gái mình nữa, không bao giờ gặp lại em gái mình nữa.\n嗯，他们刚刚离开小镇，她，她再也没有见到她的妹妹，再也没有见到她的妹妹。\nاستيقظت الجدة ، وانسحبت إلى أسفل الدرج من الشرفة ، وكانت تسير نحو الطريق ، ثم وقفت هناك.\nИ баба стана и някак си тръгна надолу по стълбите на верандата, и вървеше към пътя, и тогава просто застана там.\nAlso, Oma stand auf, lief etwa die Treppe hinunter zur Veranda hinaus und dann ist sie zur Straße gelaufen und stand dann einfach nur da.\nΈτσι η γιαγιά σηκώθηκε και κάπως κατέβηκε τα σκαλοπάτια από τη βεράντα και περπατούσε πλησιάζοντας το δρόμο και έπειτα απλά στάθηκε εκεί.\nSo Granny got up, and she kind of walked down the steps off the porch and she was walking up towards the road and she then just stood there.\nAsí que la abuela se levantó, y bajó los escalones del cobertizo, estaba caminando hacia la carretera y luego se quedó de pie allí.\nDonc Granny s'est levé et elle a descendu les escaliers de la véranda et elle avançait vers la route et elle s'est juste figée là.\nतो दादी उठ गई, और वह एक तरह से पोर्च बंद चरणों नीचे चलने लगी और वह सड़क की ओर चल रही थी और वह तो सिर्फ वहाँ खड़ी थी।\nТогда Бабушка поднялась и как-бы спустилась с веранды по ступенькам и направилась к дороге, а там встала и просто стояла.\nTena nyanyangu akaamka na akaanza kutembea chini ya ngazi za ukumbi akielekea barabarani. 
Alafu akasimama hapo bila kutembea.\nยายลุกขึ้นและเธอก็เดินลงบันไดออกจากระเบียงและเดินขึ้นไปบนถนนและเธอก็ยืนอยู่ที่นั่น\nBüyükanne kalktı ve sundurmanın önünden aşağı doğru yürüdü ve oraya doğru yürüyordu ve o da orada durdu.\nتو نانی اٹھی، اور پورچ کی سیڑیوں سے نیچے چلی گئی اور وہ سڑک کی طرف جا رہی تھی اور بس پھر وہ وہاں کھڑی ہو گئی.\nVì vậy, Granny đứng dậy, và cô ấy bước xuống bậc thềm ngoài hiên nhà và cô ấy đang bước lên đường và sau đó cô ấy chỉ đứng đó.\n奶奶站了起来，她走下门廊的台阶，走到马路边，然后站在那里。\nلكان قد ولد\nТой щеше да се е родил\nEr wäre geboren worden\nΘα είχε γεννηθεί\nHe would have been born\nÉl habría nacido.\nIl serait déjà né\nउसका जन्म हो गया होगा\nОн бы тогда родился.\nAngekuwa amezaliwa.\nเขาควรจะเกิดมา\nDoğmuş olurdu.\nوہ پیدا ہوا تھا\nĐáng nhẽ giờ nó phải chui ra rồi chứ\n他可能已经出生了\nالقصة التي سأتحدث عنها اليوم هي عن والدي والتنوعات الثقافية التي كان يتمتع بها عندما انتقل إلى أمريكا.\nИсторията, за която ще говоря днес, е за баща ми и културните различия, които имаше, когато той се мести в Америка.\nDie Geschichte die ich heute erzähle ist über meinen Vater und die Kulturunterschiede die er erlebte, als er nach Amerika zog.\nΗ ιστορία για την οποία θα μιλήσω σήμερα είναι για τον πατέρα μου και για τις πολιτισμικές διαφορές που είχε όταν μετακόμισε στην Αμερική.\nThe story I shall talk about today is about my father and the culture diversities he had when he moved to America.\nLa historia de la que hablaré hoy es sobre mi padre y las diversidades culturales que tuvo cuando se mudó a Estados Unidos.\nL'histoire dont je parlerai aujourd'hui concerne mon père et les diversités culturelles qu'il a eues lorsqu'il a déménagé en Amérique.\nआज जो कहानी आपको बताऊंगा , वह मेरे पिता और संस्कृति की विविधताओं के बारे में है, जब वह अमेरिका चले गए थे ।\nВ истории, которую я расскажу сегодня, говорится о моем отце и о культурных различиях, которые он ощущал, приехав в Америку.\nHadithi ntakayoisema leo ni kuhusu babayangu na tamaduni tofauti alizokumbana nazo alipoenda Marekani.\nเรื่องที่ฉันจะพูดถึงในวันนี้เกี่ยวกับพ่อของฉันและความแตกต่างทางวัฒนธรรมที่เขามีเมื่อเขาย้ายไปอเมริกา\nBugün hakkında konuşacağım hikaye babam ve Amerika'ya taşındığında sahip olduğu kültür çeşitliliğiyle ilgilidir.\nجس کہانی کے بارے میں میں آج گفتگو کروں گا وہ میرے والد اور ان کے اس ثقاوتی تنوع کے بارے میں ہے جب وہ امریکا میں آئے تھے.\nCâu chuyện tôi sẽ kể trong ngày hôm nay là về cha tôi và sự đa dạng văn hóa mà ông có được khi chuyển đến Mỹ.\n我今天要讲的故事是关于我的父亲和他搬到美国时的文化多样性。\nأنا وكأني أعرف كم قطعت من المسافات.\nАз съм като човек, който знае докъде е стигнал.\nIch bin wie, ich weiß, wie weit ich gekommen bin.\nΞέρω πόσο μακριά έχω πάει.\nI'm like, I know how far I've gotten.\nSoy como, sé lo lejos que he llegado.\nGenre, je sais comme j'ai avancé.\nमै यही सोच रहा था की कितना दूर आ चूका हु मै ।\nА я, такой, я и сам знаю, как далеко зашел.\nNiko kama, najua umbali nimefika.\nฉันก็ประมาณว่า ฉันรู้ว่าฉันได้อะไรบ้างที่ผ่านมา\nNe kadar uzağa gittiğimi biliyor gibiyim.\nمجھے پسند ہے، میں جانتا ہوں کہ میں کتنے دور تک پہنچ گیا ہوں.\nTôi giống như đã biết tôi đã đi được bao xa.\n我就像是，我知道我得到了多少。\nحسنًا، هل يمكنك سماعي؟\nДобре, чуваш ли ме?\nOK, kannst du mich hören?\nΕντάξει, μπορείς να με ακούσεις;\nOK, can you hear me?\nالسلام عليكم ورحمة الله وبركاته\nیہ ایک خوبصورت دن ہے\nانا اسمي محمد\nمیرا نام احمد ہے\nقال الطالب: نعم\nپاکستان زندہ باد\nالحمد لله رب العالمين\nخوش آمدید\nبسم الله الرحمن الرحيم\nتم کہاں جا رہے ہو؟\nΚαλημέρα σας\nΤο όνομά μου είναι Νίκος\nΑυτό είναι ένα βιβλίο\nΓειά σου Κώστα\nΕλλάδα είναι όμορφη\nΑθήνα, Σπάρτη, 
Κρήτη\nΠώς είσαι;\nΜου αρέσει ο καφές\nΧαίρετε!\nΣήμερα είναι Παρασκευή\nПривет, как дела?\nМосква — столица России\nЯ люблю читать книги\nСегодня холодно\nМоя мама учительница\nХорошо, спасибо\nДо свидания\nДоброе утро\nМоя фамилия Иванов\nПожалуйста, входите\n你好\n我爱你\n今天是星期五\n北京是中国的首都\n谢谢\n图书馆在这里\n这是我的朋友\n学习中文很有趣\n早上好\n再见\nBonjour, comment ça va?\nJe m'appelle Claire\nMerci beaucoup\n¿Dónde está la biblioteca?\n¡Buenos días!\nHola, me llamo Juan\nWie geht es dir?\nIch heiße Peter\nDas ist ein Buch\nGünaydın, nasılsın?\nBenim adım Ayşe\nTürkiye çok güzel\nXin chào, tôi là Nam\nTôi yêu Việt Nam\nChúc mừng năm mới\nHello, how are you?\nThis is my book\nI love programming\nThe cat is on the mat\nIt’s raining outside\nOpen the window please\nWhat time is it?\nSee you tomorrow\nGood morning everyone\nThe quick brown fox jumps over the lazy dog\nनमस्ते, आप कैसे हैं?\nमेरा नाम मोहन है\nभारत एक महान देश है\nयह किताब बहुत अच्छी है\nकल मैं स्कूल जाऊँगा\nمیں قرآن پڑھ رہا ہوں\nالقرآن الكريم\nدعاء کے بعد\nدعاء بعد الصلاة\nمحبت اور امن\nfor(int i=0; i<10; i++) {{ printf('Hello'); }}\n#include <stdio.h> // This program prints hello world\nprint('Bonjour le monde')\nconsole.log('¡Hola mundo!')\npublic static void main(String[] args)\nfunction greet() {{ return 'Hallo Welt'; }}\necho 'Merhaba Dünya'\nif(x > 5) {{ System.out.println('Xin chào'); }}\ncout << 'Привет мир';\nprintf('你好，世界');\n1234567890\n42\n!!! ??? ...\n@username #hashtag\n2025-09-08\nBonjour, today is sunny\nHola, mein Freund\nПривет, hello\nXin chào, merci beaucoup\nनमस्ते, good morning\nok\nyes\nbonjour\nhola\nnamaste\n"
  },
  {
    "path": "tinker_cookbook/exceptions.py",
    "content": "\"\"\"Centralized exception hierarchy for tinker-cookbook.\n\nAll custom exceptions inherit from :class:`TinkerCookbookError`, making it easy\nfor downstream consumers to catch *any* cookbook-specific error with a single\n``except TinkerCookbookError`` clause while still allowing fine-grained handling\nof specific error categories.\n\nThis module does **not** replace the Tinker SDK's own exception hierarchy\n(``tinker.TinkerError``, ``tinker.APIError``, etc.).  Those exceptions are\nraised by the SDK when communicating with the Tinker service; the exceptions\nhere cover errors that originate in the cookbook's own logic — configuration\nvalidation, data loading, rendering, weight management, and so on.\n\nTypical usage::\n\n    from tinker_cookbook.exceptions import ConfigurationError, DataError\n\n    if model_name not in KNOWN_MODELS:\n        raise ConfigurationError(f\"Unknown model: {model_name}\")\n\nAdding a new exception\n~~~~~~~~~~~~~~~~~~~~~~\n\n1. Subclass :class:`TinkerCookbookError` (or a category subclass like\n   :class:`DataError`).\n2. Also inherit from the stdlib exception it replaces (e.g. ``ValueError``,\n   ``RuntimeError``) so that existing ``except`` clauses keep working.\n3. Add it to :data:`__all__` below **and** to ``tinker_cookbook/__init__.py``.\n4. Keep exceptions picklable — do **not** add custom ``__init__`` parameters\n   without implementing ``__reduce__``.  Picklability is required for\n   ``multiprocessing`` and distributed task frameworks.\n\"\"\"\n\n__all__ = [\n    \"TinkerCookbookError\",\n    \"ConfigurationError\",\n    \"DataError\",\n    \"DataFormatError\",\n    \"DataValidationError\",\n    \"RendererError\",\n    \"TrainingError\",\n    \"CheckpointError\",\n    \"AllTrajectoriesFailedError\",\n    \"WeightsError\",\n    \"WeightsDownloadError\",\n    \"WeightsMergeError\",\n    \"SandboxError\",\n]\n\n\nclass TinkerCookbookError(Exception):\n    \"\"\"Base exception for all tinker-cookbook errors.\n\n    Catch this to handle any error raised by cookbook code (as opposed to\n    errors from the Tinker SDK or third-party libraries).\n    \"\"\"\n\n\n# ---------------------------------------------------------------------------\n# Configuration errors\n# ---------------------------------------------------------------------------\n\n\nclass ConfigurationError(TinkerCookbookError, ValueError):\n    \"\"\"A configuration parameter is invalid or missing.\n\n    Raised when user-supplied configuration (model names, hyperparameters,\n    renderer names, required fields, etc.) fails validation.  Inherits from\n    :class:`ValueError` for backward compatibility with code that already\n    catches ``ValueError`` for configuration problems.\n\n    Examples:\n        - Unknown model name\n        - Missing required config key (e.g. ``kl_reference_config``)\n        - Invalid hyperparameter combination\n    \"\"\"\n\n\n# ---------------------------------------------------------------------------\n# Data errors\n# ---------------------------------------------------------------------------\n\n\nclass DataError(TinkerCookbookError, ValueError):\n    \"\"\"An error related to training or evaluation data.\n\n    Base class for data-related errors.  Inherits from :class:`ValueError`\n    for backward compatibility.\n    \"\"\"\n\n\nclass DataFormatError(DataError):\n    \"\"\"Data is not in the expected format.\n\n    Raised when input data (JSONL files, HuggingFace datasets, conversation\n    dicts, etc.) is structurally malformed — e.g. 
a missing ``messages``\n    field in a JSONL line, or a conversation with too few tokens.\n    \"\"\"\n\n\nclass DataValidationError(DataError):\n    \"\"\"Data fails a semantic validation check.\n\n    Raised when data is structurally correct but violates a logical\n    constraint — e.g. streaming datasets cannot seek backward, or\n    there are not enough tokens for an input/target split.\n    \"\"\"\n\n\n# ---------------------------------------------------------------------------\n# Renderer errors\n# ---------------------------------------------------------------------------\n\n\nclass RendererError(TinkerCookbookError, ValueError):\n    \"\"\"An error related to renderer configuration or rendering.\n\n    Raised when a renderer cannot be found, messages cannot be rendered\n    into a model prompt, or a response cannot be parsed back into messages.\n    Inherits from :class:`ValueError` for backward compatibility.\n    \"\"\"\n\n\n# ---------------------------------------------------------------------------\n# Training errors\n# ---------------------------------------------------------------------------\n\n\nclass TrainingError(TinkerCookbookError, RuntimeError):\n    \"\"\"An error during a training loop.\n\n    Base class for errors that occur while executing SL, RL, DPO, or\n    distillation training loops.  Inherits from :class:`RuntimeError`\n    for backward compatibility.\n    \"\"\"\n\n\nclass CheckpointError(TrainingError):\n    \"\"\"An error related to saving, loading, or resuming checkpoints.\n\n    Raised when a checkpoint file is missing, corrupted, or when the\n    save/load operation fails.\n    \"\"\"\n\n\nclass AllTrajectoriesFailedError(TrainingError):\n    \"\"\"All trajectories in a rollout group failed.\n\n    Caught internally by the rollout pipeline to skip the affected group\n    rather than crash the training run.\n    \"\"\"\n\n\n# ---------------------------------------------------------------------------\n# Weights errors\n# ---------------------------------------------------------------------------\n\n\nclass WeightsError(TinkerCookbookError):\n    \"\"\"An error related to weight download, merge, or export.\n\n    Grouping base for weights-related errors.  Does not inherit from a\n    stdlib exception — use the specific subclasses which each carry\n    exactly one stdlib base appropriate to their failure mode.\n    \"\"\"\n\n\nclass WeightsDownloadError(WeightsError, RuntimeError):\n    \"\"\"Failed to download weights from Tinker storage.\n\n    Raised when the Tinker service cannot be reached, the checkpoint\n    path is invalid, or the download archive is corrupt.  Inherits from\n    :class:`RuntimeError` because these are operational failures.\n    \"\"\"\n\n\nclass WeightsMergeError(WeightsError, ValueError):\n    \"\"\"Failed to merge LoRA adapter weights into a base model.\n\n    Raised when adapter weights are incompatible with the base model\n    (shape mismatches, missing keys, etc.).  Inherits from\n    :class:`ValueError` because merge errors are validation failures\n    (wrong shapes, missing config keys).\n    \"\"\"\n\n\n# ---------------------------------------------------------------------------\n# Sandbox errors\n# ---------------------------------------------------------------------------\n\n\nclass SandboxError(TinkerCookbookError, RuntimeError):\n    \"\"\"An error related to code-execution sandboxes.\n\n    Base class for sandbox errors — e.g. sandbox termination, timeouts,\n    or unexpected sandbox failures.\n    \"\"\"\n"
  },
  {
    "path": "tinker_cookbook/exceptions_test.py",
    "content": "\"\"\"Tests for the exception hierarchy in tinker_cookbook.exceptions.\n\nVerifies inheritance contracts so that future changes don't accidentally\nbreak backward compatibility (stdlib bases) or the catch-all\nTinkerCookbookError base.\n\"\"\"\n\nimport pickle\n\nimport pytest\n\nfrom tinker_cookbook.exceptions import (\n    CheckpointError,\n    ConfigurationError,\n    DataError,\n    DataFormatError,\n    DataValidationError,\n    RendererError,\n    SandboxError,\n    TinkerCookbookError,\n    TrainingError,\n    WeightsDownloadError,\n    WeightsError,\n    WeightsMergeError,\n)\n\n# ---------------------------------------------------------------------------\n# Every custom exception must be a TinkerCookbookError\n# ---------------------------------------------------------------------------\n\nALL_EXCEPTIONS = [\n    ConfigurationError,\n    DataError,\n    DataFormatError,\n    DataValidationError,\n    RendererError,\n    TrainingError,\n    CheckpointError,\n    WeightsError,\n    WeightsDownloadError,\n    WeightsMergeError,\n    SandboxError,\n]\n\n\n@pytest.mark.parametrize(\"exc_cls\", ALL_EXCEPTIONS, ids=lambda c: c.__name__)\ndef test_all_exceptions_are_tinker_cookbook_errors(exc_cls: type[Exception]):\n    assert issubclass(exc_cls, TinkerCookbookError)\n    assert isinstance(exc_cls(\"test\"), TinkerCookbookError)\n\n\n# ---------------------------------------------------------------------------\n# Backward-compatible stdlib bases\n# ---------------------------------------------------------------------------\n\nSTDLIB_COMPAT = [\n    (ConfigurationError, ValueError),\n    (DataError, ValueError),\n    (DataFormatError, ValueError),\n    (DataValidationError, ValueError),\n    (RendererError, ValueError),\n    (TrainingError, RuntimeError),\n    (CheckpointError, RuntimeError),\n    (WeightsDownloadError, RuntimeError),\n    (WeightsMergeError, ValueError),\n    (SandboxError, RuntimeError),\n]\n\n\n@pytest.mark.parametrize(\n    \"exc_cls, stdlib_base\",\n    STDLIB_COMPAT,\n    ids=lambda x: x.__name__ if isinstance(x, type) else \"\",\n)\ndef test_stdlib_backward_compatibility(exc_cls: type[Exception], stdlib_base: type[Exception]):\n    \"\"\"Existing `except ValueError:` / `except RuntimeError:` handlers must keep working.\"\"\"\n    assert issubclass(exc_cls, stdlib_base)\n    assert isinstance(exc_cls(\"test\"), stdlib_base)\n\n\n# ---------------------------------------------------------------------------\n# Subclass relationships\n# ---------------------------------------------------------------------------\n\n\ndef test_data_subtypes():\n    assert issubclass(DataFormatError, DataError)\n    assert issubclass(DataValidationError, DataError)\n\n\ndef test_training_subtypes():\n    assert issubclass(CheckpointError, TrainingError)\n\n\ndef test_weights_subtypes():\n    assert issubclass(WeightsDownloadError, WeightsError)\n    assert issubclass(WeightsMergeError, WeightsError)\n\n\n# ---------------------------------------------------------------------------\n# SandboxTerminatedError integration\n# ---------------------------------------------------------------------------\n\n\ndef test_sandbox_terminated_error_is_sandbox_error():\n    from tinker_cookbook.sandbox.sandbox_interface import SandboxTerminatedError\n\n    assert issubclass(SandboxTerminatedError, SandboxError)\n    assert issubclass(SandboxTerminatedError, TinkerCookbookError)\n    assert issubclass(SandboxTerminatedError, RuntimeError)\n\n\n# 
---------------------------------------------------------------------------\n# Picklability (required for multiprocessing / distributed tasks)\n# ---------------------------------------------------------------------------\n\n\n@pytest.mark.parametrize(\"exc_cls\", ALL_EXCEPTIONS, ids=lambda c: c.__name__)\ndef test_exceptions_are_picklable(exc_cls: type[Exception]):\n    \"\"\"All exceptions must survive pickle round-trip for multiprocessing.\"\"\"\n    original = exc_cls(\"test message\")\n    roundtripped = pickle.loads(pickle.dumps(original))\n    assert type(roundtripped) is type(original)\n    assert str(roundtripped) == str(original)\n    assert roundtripped.args == original.args\n\n\n# ---------------------------------------------------------------------------\n# __all__ is in sync\n# ---------------------------------------------------------------------------\n\n\ndef test_exceptions_all_is_complete():\n    \"\"\"__all__ in exceptions.py must list every public exception class.\"\"\"\n    import tinker_cookbook.exceptions as mod\n\n    public_exc_classes = {\n        name\n        for name, obj in vars(mod).items()\n        if isinstance(obj, type)\n        and issubclass(obj, TinkerCookbookError)\n        and not name.startswith(\"_\")\n    }\n    assert public_exc_classes == set(mod.__all__)\n"
  },
  {
    "path": "tinker_cookbook/hyperparam_utils.py",
    "content": "\"\"\"\nUtilities for guessing good hyperparameters for fine-tuning.\n\"\"\"\n\nimport json\nimport math\nimport struct\n\nimport huggingface_hub\nimport numpy as np\nfrom transformers import AutoConfig\n\nfrom tinker_cookbook.exceptions import ConfigurationError\nfrom tinker_cookbook.utils.misc_utils import not_none\n\n\ndef _list_param_shapes_from_safetensors_remote(\n    repo_id: str,\n    revision: str = \"main\",\n    token: str | None = None,\n) -> dict[str, tuple[int, ...]]:\n    \"\"\"\n    Returns {param_name: shape_tuple} by reading ONLY the safetensors header(s)\n    over HTTP (ranged requests). No full file download.\n    \"\"\"\n    fs = huggingface_hub.HfFileSystem(token=token)\n    info = huggingface_hub.model_info(repo_id, revision=revision, token=token)\n\n    # find all .safetensors files (handles sharded checkpoints)\n    st_files = [\n        s.rfilename for s in not_none(info.siblings) if s.rfilename.endswith(\".safetensors\")\n    ]\n    if not st_files:\n        raise FileNotFoundError(\"No .safetensors files found in this repo.\")\n\n    shapes: dict[str, tuple[int, ...]] = {}\n\n    for fname in st_files:\n        # Open remote file via fsspec; this performs HTTP range reads under the hood\n        path = f\"{repo_id}@{revision}/{fname}\"  # HfFileSystem path format\n        with fs.open(path, \"rb\") as f:\n            # safetensors spec:\n            # [0:8] = little-endian u64 header_len\n            # [8:8+header_len] = UTF-8 JSON header\n            header_len_bytes = f.read(8)\n            assert isinstance(header_len_bytes, bytes)\n            if len(header_len_bytes) < 8:\n                raise OSError(f\"File too small or not safetensors: {fname}\")\n            (header_len,) = struct.unpack(\"<Q\", header_len_bytes)\n\n            header_bytes = f.read(header_len)\n            assert isinstance(header_bytes, bytes)\n            if len(header_bytes) < header_len:\n                raise OSError(f\"Incomplete header read for {fname}\")\n\n            header = json.loads(header_bytes.decode(\"utf-8\"))\n            # header maps tensor_name -> { \"dtype\": \"...\", \"shape\": [...], \"data_offsets\": [start, end] }\n            for name, meta in header.items():\n                if name == \"__metadata__\":  # optional global metadata block\n                    continue\n                shapes[name] = tuple(meta[\"shape\"])\n\n    return shapes\n\n\ndef get_lora_lr_over_full_finetune_lr(model_name: str, lora_alpha: int = 32) -> float:\n    \"\"\"\n    Return the factor that you should scale the full fine-tuning learning rate by to get the equivalent LoRA learning rate.\n    Previously we had a more complicated formula, but the factor of 10 was more accurate empirically.\n    See Lora Without Regret (https://thinkingmachines.ai/blog/lora/) for more details.\n    \"\"\"\n    return 10.0\n\n\ndef _get_hidden_size(model_name: str) -> int:\n    if \"meta-llama/Llama-3\" in model_name:\n        # Bypass HF_TOKEN requirement for Llama-3 models\n        return {\n            \"meta-llama/Llama-3.2-1B\": 2048,\n            \"meta-llama/Llama-3.2-1B-Instruct\": 2048,\n            \"meta-llama/Llama-3.2-3B\": 3072,\n            \"meta-llama/Llama-3.2-3B-Instruct\": 3072,\n            \"meta-llama/Llama-3.1-8B\": 4096,\n            \"meta-llama/Llama-3.1-8B-Instruct\": 4096,\n            \"meta-llama/Llama-3.1-70B\": 8192,\n            \"meta-llama/Llama-3.3-70B-Instruct\": 8192,\n        }[model_name]\n\n    if model_name in (\n        
\"deepseek-ai/DeepSeek-V3.1\",\n        \"deepseek-ai/DeepSeek-V3.1-Base\",\n        \"moonshotai/Kimi-K2-Thinking\",\n    ):\n        return 7168\n\n    config = AutoConfig.from_pretrained(model_name)\n    return config.hidden_size\n\n\ndef get_lora_param_count(\n    model_name: str,\n    lora_rank: int = 32,\n    detailed: bool = False,\n    include_experts: bool = True,\n    shared_expert_outer_loras: bool = True,\n) -> int | dict[str, int]:\n    \"\"\"\n    Get the number of parameters in the LoRA adapter.\n    \"\"\"\n\n    dim_sum = 0\n    dim_sum_experts = 0\n    ignore = [\"gate\", \"embed_tokens\", \"q_b_proj\", \"kv_b_proj\"]\n    if not include_experts:\n        ignore.append(\"experts\")\n\n    for name, shape in _list_param_shapes_from_safetensors_remote(model_name).items():\n        if (\n            len(shape) == 2\n            and name.endswith(\".weight\")\n            and not any(v in name.split(\".\") for v in ignore)\n        ):\n            parts = name.split(\".\")\n            if \"experts\" not in parts or not shared_expert_outer_loras:\n                dim_sum += shape[0] + shape[1]\n            else:\n                # For expert shared outer_loras, we only count the outer dims once, since they are shared across experts\n                expert_idx = int(parts[parts.index(\"experts\") + 1])\n                weight_name = parts[parts.index(\"experts\") + 2]\n                assert weight_name in [\"gate_proj\", \"down_proj\", \"up_proj\"], (\n                    f\"Unexpected expert weight name: {weight_name}\"\n                )\n                intermediate_dim = shape[1] if weight_name == \"down_proj\" else shape[0]\n                outer_dim = shape[0] if weight_name == \"down_proj\" else shape[1]\n\n                dim_sum_experts += intermediate_dim\n                if expert_idx == 0:\n                    dim_sum_experts += outer_dim\n\n    non_expert_params = lora_rank * dim_sum\n    expert_params = lora_rank * dim_sum_experts\n\n    return (\n        (expert_params + non_expert_params)\n        if not detailed\n        else {\n            \"expert_params\": expert_params,\n            \"non_expert_params\": non_expert_params,\n            \"total_params\": expert_params + non_expert_params,\n        }\n    )\n\n\ndef get_lr(model_name: str, is_lora: bool = True) -> float:\n    base_lr = 5e-05\n    lora_multiplier = 10.0\n\n    lr = base_lr * lora_multiplier if is_lora else base_lr\n    if \"llama\" in model_name.lower():\n        exponent_model = 0.781\n    elif \"qwen\" in model_name.lower():\n        exponent_model = 0.0775\n    elif model_name in (\n        \"deepseek-ai/DeepSeek-V3.1\",\n        \"deepseek-ai/DeepSeek-V3.1-Base\",\n        \"openai/gpt-oss-20b\",\n        \"openai/gpt-oss-120b\",\n        \"moonshotai/Kimi-K2-Thinking\",\n        \"moonshotai/Kimi-K2.5\",\n        \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\",\n        \"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16\",\n    ):\n        raise NotImplementedError(\n            f\"Learning rate formula for {model_name} is not yet calibrated. 
\"\n            \"Please specify a learning rate manually.\"\n        )\n    else:\n        raise ConfigurationError(f\"Unknown model: {model_name}\")\n    # TODO: sweep to determine LR multipliers for other models\n    lr = lr * (2000 / _get_hidden_size(model_name)) ** exponent_model\n    return lr\n\n\ndef get_full_finetune_param_count(model_name: str) -> float:\n    count = 0\n    for _name, shape in _list_param_shapes_from_safetensors_remote(model_name).items():\n        count += np.prod(shape)\n    return float(count)\n\n\ndef get_full_finetune_lr_multiplier(model_name: str):\n    return 1.0 / math.sqrt(get_full_finetune_param_count(model_name))\n\n\ndef get_lora_lr_multiplier(model_name: str):\n    \"\"\"\n    Get a model-specific mutliplier for the LR, when training with LoRA.\n    Given two models A and B, and learning rate LR_A that's known to be optimal for A,\n    we can guess an optimal learning rate for B as\n    LR_B = LR_A * get_lora_lr_multiplier(B) / get_lora_lr_multiplier(A)\n    \"\"\"\n    return get_full_finetune_lr_multiplier(model_name) * get_lora_lr_over_full_finetune_lr(\n        model_name\n    )\n"
  },
  {
    "path": "tinker_cookbook/image_processing_utils.py",
    "content": "\"\"\"\nUtilities for working with image processors. Create new types to avoid needing to import AutoImageProcessor and BaseImageProcessor.\n\n\nAvoid importing AutoImageProcessor and BaseImageProcessor until runtime, because they're slow imports.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom functools import cache\nfrom typing import TYPE_CHECKING, Any, TypeAlias\n\nfrom PIL import Image\n\nif TYPE_CHECKING:\n    # this import takes a few seconds, so avoid it on the module import when possible\n    from transformers.image_processing_utils import BaseImageProcessor\n\n    ImageProcessor: TypeAlias = BaseImageProcessor\nelse:\n    # make it importable from other files as a type in runtime\n    ImageProcessor: TypeAlias = Any\n\n\n@cache\ndef get_image_processor(model_name: str) -> ImageProcessor:\n    model_name = model_name.split(\":\")[0]\n\n    from transformers.models.auto.image_processing_auto import AutoImageProcessor\n\n    kwargs: dict[str, Any] = {}\n    if os.environ.get(\"HF_TRUST_REMOTE_CODE\", \"\").lower() in (\"1\", \"true\", \"yes\"):\n        kwargs[\"trust_remote_code\"] = True\n\n    if model_name == \"moonshotai/Kimi-K2.5\":\n        kwargs[\"trust_remote_code\"] = True\n        kwargs[\"revision\"] = \"3367c8d1c68584429fab7faf845a32d5195b6ac1\"\n\n    processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True, **kwargs)\n    return processor\n\n\ndef resize_image(image: Image.Image, max_size: int) -> Image.Image:\n    \"\"\"\n    Resize an image so that its longest side is at most max_size pixels.\n\n    Preserves aspect ratio and uses LANCZOS resampling for quality.\n    Returns the original image if it's already smaller than max_size.\n    \"\"\"\n\n    width, height = image.size\n    if max(width, height) <= max_size:\n        return image\n\n    if width > height:\n        new_width = max_size\n        new_height = int(height * max_size / width)\n    else:\n        new_height = max_size\n        new_width = int(width * max_size / height)\n\n    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)\n"
  },
  {
    "path": "tinker_cookbook/image_processing_utils_test.py",
    "content": "from unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom tinker_cookbook.image_processing_utils import get_image_processor\n\n\n@pytest.fixture(autouse=True)\ndef _clear_cache() -> None:\n    \"\"\"Clear the lru_cache between tests so env var changes take effect.\"\"\"\n    get_image_processor.cache_clear()\n\n\n@patch(\"transformers.models.auto.image_processing_auto.AutoImageProcessor\")\ndef test_kimi_k25_trusts_remote_code_without_env(\n    mock_auto: MagicMock, monkeypatch: pytest.MonkeyPatch\n) -> None:\n    \"\"\"Hardcoded Kimi K2.5 should pass trust_remote_code=True without the env var.\"\"\"\n    monkeypatch.delenv(\"HF_TRUST_REMOTE_CODE\", raising=False)\n    get_image_processor(\"moonshotai/Kimi-K2.5\")\n    mock_auto.from_pretrained.assert_called_once_with(\n        \"moonshotai/Kimi-K2.5\",\n        use_fast=True,\n        trust_remote_code=True,\n        revision=\"3367c8d1c68584429fab7faf845a32d5195b6ac1\",\n    )\n\n\n@patch(\"transformers.models.auto.image_processing_auto.AutoImageProcessor\")\ndef test_no_trust_remote_code_by_default(\n    mock_auto: MagicMock, monkeypatch: pytest.MonkeyPatch\n) -> None:\n    \"\"\"Without env var, generic models should NOT get trust_remote_code.\"\"\"\n    monkeypatch.delenv(\"HF_TRUST_REMOTE_CODE\", raising=False)\n    get_image_processor(\"some-org/some-model\")\n    mock_auto.from_pretrained.assert_called_once_with(\n        \"some-org/some-model\",\n        use_fast=True,\n    )\n\n\n@pytest.mark.parametrize(\"env_value\", [\"1\", \"true\", \"TRUE\", \"yes\"])\n@patch(\"transformers.models.auto.image_processing_auto.AutoImageProcessor\")\ndef test_env_var_enables_trust_remote_code(\n    mock_auto: MagicMock, monkeypatch: pytest.MonkeyPatch, env_value: str\n) -> None:\n    \"\"\"HF_TRUST_REMOTE_CODE env var should enable trust_remote_code for any model.\"\"\"\n    monkeypatch.setenv(\"HF_TRUST_REMOTE_CODE\", env_value)\n    get_image_processor(\"some-org/some-model\")\n    mock_auto.from_pretrained.assert_called_once_with(\n        \"some-org/some-model\",\n        use_fast=True,\n        trust_remote_code=True,\n    )\n"
  },
  {
    "path": "tinker_cookbook/model_info.py",
    "content": "\"\"\"\nThis module associates model names with metadata, which helps  training code choose good defaults.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom dataclasses import dataclass\nfrom functools import cache\n\nfrom tinker_cookbook.exceptions import ConfigurationError\n\nlogger = logging.getLogger(__name__)\n\n# Common renderer tuples, defined once to reduce repetition.\n# Tuples (not lists) because these are shared across ModelAttributes instances\n# — a mutable list would risk silent cross-model corruption if mutated.\n_LLAMA3 = (\"llama3\",)\n_ROLE_COLON = (\"role_colon\",)\n_QWEN3 = (\"qwen3\", \"qwen3_disable_thinking\")\n_QWEN3_INSTRUCT = (\"qwen3_instruct\",)\n_QWEN3_VL = (\"qwen3_vl\",)\n_QWEN3_VL_INSTRUCT = (\"qwen3_vl_instruct\",)\n_QWEN3_5 = (\"qwen3_5\", \"qwen3_5_disable_thinking\")\n_DEEPSEEKV3 = (\"deepseekv3\", \"deepseekv3_thinking\")\n_GPT_OSS = (\"gpt_oss_no_sysprompt\", \"gpt_oss_medium_reasoning\")\n_KIMI_K2 = (\"kimi_k2\",)\n_KIMI_K25 = (\"kimi_k25\", \"kimi_k25_disable_thinking\")\n_NEMOTRON3 = (\"nemotron3\", \"nemotron3_disable_thinking\")\n\n\n@dataclass\nclass ModelAttributes:\n    organization: str  # meta-llama, Qwen, etc.\n    version_str: str  # just the version number e.g. \"3.1\", \"2.5\"\n    size_str: str  # size of the model e.g. \"8B\", \"72B\", \"1.5B\"\n    is_chat: bool  # is chat/instruct model\n    recommended_renderers: tuple[str, ...]  # first entry is the most recommended\n    is_vl: bool = False  # is vision-language model\n\n\n@cache\ndef get_llama_info() -> dict[str, ModelAttributes]:\n    org = \"meta-llama\"\n    return {\n        \"Llama-3.2-1B-Instruct\": ModelAttributes(org, \"3.2\", \"1B\", True, _LLAMA3),\n        \"Llama-3.2-3B-Instruct\": ModelAttributes(org, \"3.2\", \"3B\", True, _LLAMA3),\n        \"Llama-3.1-8B-Instruct\": ModelAttributes(org, \"3.1\", \"8B\", True, _LLAMA3),\n        \"Llama-3.2-1B\": ModelAttributes(org, \"3.2\", \"1B\", False, _ROLE_COLON),\n        \"Llama-3.2-3B\": ModelAttributes(org, \"3.2\", \"3B\", False, _ROLE_COLON),\n        \"Llama-3.1-8B\": ModelAttributes(org, \"3.1\", \"8B\", False, _ROLE_COLON),\n        \"Llama-3.1-70B\": ModelAttributes(org, \"3.1\", \"70B\", False, _ROLE_COLON),\n        \"Llama-3.3-70B-Instruct\": ModelAttributes(org, \"3.3\", \"70B\", True, _LLAMA3),\n    }\n\n\n@cache\ndef get_qwen_info() -> dict[str, ModelAttributes]:\n    org = \"Qwen\"\n    return {\n        \"Qwen3-VL-30B-A3B-Instruct\": ModelAttributes(\n            org, \"3\", \"30B-A3B\", True, _QWEN3_VL_INSTRUCT, is_vl=True\n        ),\n        \"Qwen3-VL-235B-A22B-Instruct\": ModelAttributes(\n            org, \"3\", \"235B-A22B\", True, _QWEN3_VL_INSTRUCT, is_vl=True\n        ),\n        \"Qwen3-4B-Base\": ModelAttributes(org, \"3\", \"4B\", False, _ROLE_COLON),\n        \"Qwen3-8B-Base\": ModelAttributes(org, \"3\", \"8B\", False, _ROLE_COLON),\n        \"Qwen3-14B-Base\": ModelAttributes(org, \"3\", \"14B\", False, _ROLE_COLON),\n        \"Qwen3-30B-A3B-Base\": ModelAttributes(org, \"3\", \"30B-A3B\", False, _ROLE_COLON),\n        \"Qwen3-0.6B\": ModelAttributes(org, \"3\", \"0.6B\", True, _QWEN3),\n        \"Qwen3-1.7B\": ModelAttributes(org, \"3\", \"1.7B\", True, _QWEN3),\n        \"Qwen3-4B\": ModelAttributes(org, \"3\", \"4B\", True, _QWEN3),\n        \"Qwen3-8B\": ModelAttributes(org, \"3\", \"8B\", True, _QWEN3),\n        \"Qwen3-14B\": ModelAttributes(org, \"3\", \"14B\", True, _QWEN3),\n        \"Qwen3-32B\": ModelAttributes(org, \"3\", \"32B\", True, 
_QWEN3),\n        \"Qwen3-30B-A3B\": ModelAttributes(org, \"3\", \"30B-A3B\", True, _QWEN3),\n        \"Qwen3-4B-Instruct-2507\": ModelAttributes(org, \"3\", \"4B\", True, _QWEN3_INSTRUCT),\n        \"Qwen3-30B-A3B-Instruct-2507\": ModelAttributes(org, \"3\", \"30B-A3B\", True, _QWEN3_INSTRUCT),\n        \"Qwen3-235B-A22B-Instruct-2507\": ModelAttributes(\n            org, \"3\", \"235B-A22B\", True, _QWEN3_INSTRUCT\n        ),\n        \"Qwen3.5-4B\": ModelAttributes(org, \"3.5\", \"4B\", True, _QWEN3_5, is_vl=True),\n        \"Qwen3.5-27B\": ModelAttributes(org, \"3.5\", \"27B\", True, _QWEN3_5, is_vl=True),\n        \"Qwen3.5-35B-A3B\": ModelAttributes(org, \"3.5\", \"35B-A3B\", True, _QWEN3_5, is_vl=True),\n        \"Qwen3.5-397B-A17B\": ModelAttributes(org, \"3.5\", \"397B-A17B\", True, _QWEN3_5, is_vl=True),\n    }\n\n\n@cache\ndef get_deepseek_info() -> dict[str, ModelAttributes]:\n    org = \"deepseek-ai\"\n    return {\n        \"DeepSeek-V3.1\": ModelAttributes(org, \"3\", \"671B-A37B\", True, _DEEPSEEKV3),\n        \"DeepSeek-V3.1-Base\": ModelAttributes(org, \"3\", \"671B-A37B\", False, _ROLE_COLON),\n    }\n\n\n@cache\ndef get_gpt_oss_info() -> dict[str, ModelAttributes]:\n    org = \"openai\"\n    return {\n        \"gpt-oss-20b\": ModelAttributes(org, \"1\", \"21B-A3.6B\", True, _GPT_OSS),\n        \"gpt-oss-120b\": ModelAttributes(org, \"1\", \"117B-A5.1B\", True, _GPT_OSS),\n    }\n\n\n@cache\ndef get_moonshot_info() -> dict[str, ModelAttributes]:\n    org = \"moonshotai\"\n    return {\n        \"Kimi-K2-Thinking\": ModelAttributes(org, \"K2\", \"1T-A32B\", True, _KIMI_K2),\n        \"Kimi-K2.5\": ModelAttributes(org, \"K2.5\", \"1T-A32B\", True, _KIMI_K25, is_vl=True),\n    }\n\n\n@cache\ndef get_nvidia_info() -> dict[str, ModelAttributes]:\n    org = \"nvidia\"\n    return {\n        \"NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\": ModelAttributes(\n            org, \"3\", \"30B-A3B\", True, _NEMOTRON3\n        ),\n        \"NVIDIA-Nemotron-3-Super-120B-A12B-BF16\": ModelAttributes(\n            org, \"3\", \"120B-A12B\", True, _NEMOTRON3\n        ),\n    }\n\n\ndef get_model_attributes(model_name: str) -> ModelAttributes:\n    model_name = model_name.split(\":\")[0]\n    org, model_version_full = model_name.split(\"/\")\n    model_version_full = model_version_full.split(\":\")[0]\n    if org == \"meta-llama\":\n        return get_llama_info()[model_version_full]\n    elif org == \"Qwen\":\n        return get_qwen_info()[model_version_full]\n    elif org == \"deepseek-ai\":\n        return get_deepseek_info()[model_version_full]\n    elif org == \"openai\":\n        return get_gpt_oss_info()[model_version_full]\n    elif org == \"moonshotai\":\n        return get_moonshot_info()[model_version_full]\n    elif org == \"nvidia\":\n        return get_nvidia_info()[model_version_full]\n    else:\n        raise ConfigurationError(f\"Unknown model: {model_name}\")\n\n\ndef get_recommended_renderer_names(model_name: str) -> list[str]:\n    \"\"\"\n    Return a list of renderers that are designed for the model.\n    Used so we can emit a warning if you use a non-recommended renderer.\n    The first result is the most recommended renderer for the model.\n    \"\"\"\n    return list(get_model_attributes(model_name).recommended_renderers)\n\n\ndef get_recommended_renderer_name(model_name: str) -> str:\n    \"\"\"\n    Return the most recommended renderer for the model.\n    \"\"\"\n    return get_recommended_renderer_names(model_name)[0]\n\n\ndef 
warn_if_renderer_not_recommended(model_name: str, renderer_name: str | None) -> None:\n    \"\"\"\n    Log a warning if ``renderer_name`` is not in the recommended list for ``model_name``.\n\n    Silently returns if ``renderer_name`` is None (caller is using the default) or if\n    ``model_name`` is not in the model registry.\n    \"\"\"\n    if renderer_name is None:\n        return\n    try:\n        recommended = get_recommended_renderer_names(model_name)\n    except (ConfigurationError, KeyError, ValueError):\n        # Unknown model — nothing to validate against.\n        return\n    if renderer_name not in recommended:\n        logger.warning(\n            \"Renderer %r is not recommended for model %r. \"\n            \"Recommended renderer(s): %s. \"\n            \"Using an incompatible renderer can silently degrade training quality \"\n            \"(e.g., prefilling tokens the model was never trained on).\",\n            renderer_name,\n            model_name,\n            \", \".join(repr(r) for r in recommended),\n        )\n"
  },
  {
    "path": "tinker_cookbook/model_info_test.py",
    "content": "import logging\n\nimport pytest\n\nfrom tinker_cookbook.model_info import warn_if_renderer_not_recommended\n\n\nclass TestWarnIfRendererNotRecommended:\n    def test_no_warning_when_renderer_is_none(self, caplog: pytest.LogCaptureFixture):\n        with caplog.at_level(logging.WARNING):\n            warn_if_renderer_not_recommended(\"Qwen/Qwen3-4B-Instruct-2507\", None)\n        assert caplog.text == \"\"\n\n    def test_no_warning_when_renderer_is_recommended(self, caplog: pytest.LogCaptureFixture):\n        with caplog.at_level(logging.WARNING):\n            warn_if_renderer_not_recommended(\"Qwen/Qwen3-4B-Instruct-2507\", \"qwen3_instruct\")\n        assert caplog.text == \"\"\n\n    def test_warning_when_renderer_not_recommended(self, caplog: pytest.LogCaptureFixture):\n        with caplog.at_level(logging.WARNING):\n            warn_if_renderer_not_recommended(\n                \"Qwen/Qwen3-4B-Instruct-2507\", \"qwen3_disable_thinking\"\n            )\n        assert \"not recommended\" in caplog.text\n        assert \"qwen3_disable_thinking\" in caplog.text\n        assert \"qwen3_instruct\" in caplog.text\n\n    def test_no_warning_for_unknown_model(self, caplog: pytest.LogCaptureFixture):\n        with caplog.at_level(logging.WARNING):\n            warn_if_renderer_not_recommended(\"unknown/model\", \"qwen3\")\n        assert caplog.text == \"\"\n\n    def test_warning_for_thinking_renderer_on_thinking_model_alt(\n        self, caplog: pytest.LogCaptureFixture\n    ):\n        \"\"\"qwen3_disable_thinking is valid for Qwen3-8B (a thinking model).\"\"\"\n        with caplog.at_level(logging.WARNING):\n            warn_if_renderer_not_recommended(\"Qwen/Qwen3-8B\", \"qwen3_disable_thinking\")\n        assert caplog.text == \"\"\n\n    def test_warning_for_wrong_family(self, caplog: pytest.LogCaptureFixture):\n        \"\"\"llama3 renderer is not recommended for a Qwen model.\"\"\"\n        with caplog.at_level(logging.WARNING):\n            warn_if_renderer_not_recommended(\"Qwen/Qwen3-8B\", \"llama3\")\n        assert \"not recommended\" in caplog.text\n"
  },
  {
    "path": "tinker_cookbook/preference/__init__.py",
    "content": ""
  },
  {
    "path": "tinker_cookbook/preference/comparison_policy_evaluator.py",
    "content": "import asyncio\nfrom collections.abc import Callable, Sequence\nfrom dataclasses import replace\n\nimport numpy as np\nimport tinker\n\nfrom tinker_cookbook.completers import TinkerMessageCompleter\nfrom tinker_cookbook.eval.evaluators import SamplingClientEvaluator\nfrom tinker_cookbook.preference.types import (\n    Comparison,\n    PreferenceModel,\n)\nfrom tinker_cookbook.renderers import get_renderer, get_text_content\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\nclass ComparisonEvaluator(SamplingClientEvaluator):\n    \"\"\"\n    Evaluates a policy by comparing its completions to references, with a reward model\n    \"\"\"\n\n    def __init__(\n        self,\n        preference_model_builder: Callable[[], PreferenceModel],\n        comparisons: Sequence[Comparison],\n        renderer_name: str,\n        model_name_for_tokenizer: str,\n        both_ways: bool = True,\n        max_tokens: int = 1024,\n        content_preprocessor: Callable[[str], str] | None = None,\n    ):\n        self.preference_model_builder = preference_model_builder\n        self.both_ways = both_ways\n        self.comparisons = comparisons\n        self.renderer = get_renderer(renderer_name, get_tokenizer(model_name_for_tokenizer))\n        self.max_tokens = max_tokens\n        if content_preprocessor is None:\n            self.content_preprocessor = lambda x: x\n        else:\n            self.content_preprocessor = content_preprocessor\n\n    async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:\n        preference_model = self.preference_model_builder()\n        policy = TinkerMessageCompleter(sampling_client, self.renderer, self.max_tokens)\n\n        async def process_comparison(comparison: Comparison) -> float:\n            new_completion_message = await policy(comparison.prompt_conversation)\n            new_completion_content = get_text_content(new_completion_message)\n            new_completion_message = {\n                \"role\": \"assistant\",\n                \"content\": self.content_preprocessor(new_completion_content),\n            }\n            new_comparison = replace(comparison, completion_B=[new_completion_message])\n            r_0, r_1 = await asyncio.gather(\n                preference_model(new_comparison), preference_model(new_comparison.swap())\n            )\n            # r_0, r_1 are in between -1 and 1\n            # so r0-r1 is in between -2 and 2, and we normalize it to 0-1\n            return (r_0 - r_1 + 2) / 4.0\n\n        results = await asyncio.gather(\n            *[process_comparison(comparison) for comparison in self.comparisons]\n        )\n        return {\n            \"win_rate\": np.mean(results).item(),\n            \"stderr\": np.std(results).item() / np.sqrt(len(results)),\n        }\n"
  },
  {
    "path": "tinker_cookbook/preference/dpo_datasets.py",
    "content": "import chz\nimport tinker\n\nfrom tinker_cookbook.preference.preference_datasets import (\n    ComparisonDatasetBuilder,\n)\nfrom tinker_cookbook.preference.types import (\n    LabeledComparison,\n)\nfrom tinker_cookbook.supervised.common import datum_from_model_input_weights\nfrom tinker_cookbook.supervised.data import SupervisedDatasetFromHFDataset\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilder, SupervisedDataset\n\n\n@chz.chz\nclass DPODatasetBuilderFromComparisons(ChatDatasetBuilder):\n    \"\"\"\n    DPO dataset builder that uses a ComparisonDatasetBuilder.\n    DPO needs both chosen and rejected examples for training.\n    \"\"\"\n\n    comparison_builder: ComparisonDatasetBuilder\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        train_dataset, test_dataset = self.comparison_builder.get_train_and_test_datasets()\n        renderer = self.renderer\n\n        def comparison_to_datum(labeled_comparison: LabeledComparison) -> list[tinker.Datum]:\n            chosen_completion = (\n                labeled_comparison.comparison.completion_A\n                if labeled_comparison.label == \"A\"\n                else labeled_comparison.comparison.completion_B\n            )\n            rejected_completion = (\n                labeled_comparison.comparison.completion_B\n                if labeled_comparison.label == \"A\"\n                else labeled_comparison.comparison.completion_A\n            )\n\n            chosen_convo = [\n                *labeled_comparison.comparison.prompt_conversation,\n                *chosen_completion,\n            ]\n            rejected_convo = [\n                *labeled_comparison.comparison.prompt_conversation,\n                *rejected_completion,\n            ]\n\n            chosen_tokens, chosen_weights = renderer.build_supervised_example(chosen_convo)\n            rejected_tokens, rejected_weights = renderer.build_supervised_example(rejected_convo)\n\n            return [\n                datum_from_model_input_weights(\n                    chosen_tokens, chosen_weights, self.common_config.max_length\n                ),\n                datum_from_model_input_weights(\n                    rejected_tokens, rejected_weights, self.common_config.max_length\n                ),\n            ]\n\n        def example_to_data(example: dict[str, str]) -> list[tinker.Datum]:\n            labeled_comparison = self.comparison_builder.example_to_labeled_comparison(example)\n            if labeled_comparison is None:\n                return []\n            return comparison_to_datum(labeled_comparison)\n\n        if test_dataset is not None:\n            test_supervised_dataset = SupervisedDatasetFromHFDataset(\n                test_dataset,\n                batch_size=len(test_dataset),\n                flatmap_fn=example_to_data,\n            )\n        else:\n            test_supervised_dataset = None\n\n        return SupervisedDatasetFromHFDataset(\n            train_dataset, batch_size=self.common_config.batch_size, flatmap_fn=example_to_data\n        ), test_supervised_dataset\n"
  },
  {
    "path": "tinker_cookbook/preference/preference_datasets.py",
    "content": "import logging\nimport random\n\nimport chz\nimport datasets\nimport tinker\n\nfrom tinker_cookbook.preference.types import (\n    Comparison,\n    ComparisonRenderer,\n    ComparisonRendererFromChatRenderer,\n    LabeledComparison,\n)\nfrom tinker_cookbook.supervised.common import datum_from_model_input_weights\nfrom tinker_cookbook.supervised.data import SupervisedDatasetFromHFDataset\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilder, SupervisedDataset\n\nlogger = logging.getLogger(__name__)\n\n\n# ============================================================================\n# Base Classes\n# ============================================================================\n\n\n@chz.chz\nclass ComparisonDatasetBuilder:\n    \"\"\"\n    Builds HF datasets and converts to LabeledComparisons.\n    This class is independent of rendering/tokenization.\n    \"\"\"\n\n    swap: bool = False  # do data augmentation by swapping the order of the completions\n\n    def get_train_and_test_datasets(self) -> tuple[datasets.Dataset, datasets.Dataset | None]:\n        \"\"\"Get raw HuggingFace datasets for train and test.\"\"\"\n        raise NotImplementedError\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        \"\"\"Convert a HuggingFace dataset example to a LabeledComparison.\"\"\"\n        raise NotImplementedError\n\n    def get_labeled_comparisons(\n        self,\n    ) -> tuple[list[LabeledComparison], list[LabeledComparison] | None]:\n        \"\"\"Get all labeled comparisons for train and test sets.\"\"\"\n        train_dataset, test_dataset = self.get_train_and_test_datasets()\n\n        # Process train dataset\n        train_comparisons = []\n        for i in range(len(train_dataset)):\n            example = train_dataset[i]\n            labeled_comparison = self.example_to_labeled_comparison(example)\n            if labeled_comparison is not None:\n                train_comparisons.append(labeled_comparison)\n\n        # Process test dataset if it exists\n        test_comparisons = None\n        if test_dataset is not None:\n            test_comparisons = []\n            for i in range(len(test_dataset)):\n                example = test_dataset[i]\n                labeled_comparison = self.example_to_labeled_comparison(example)\n                if labeled_comparison is not None:\n                    test_comparisons.append(labeled_comparison)\n\n        return train_comparisons, test_comparisons\n\n\n@chz.chz\nclass ChatDatasetBuilderFromComparisons(ChatDatasetBuilder):\n    \"\"\"\n    Abstract base for chat dataset builders that use comparisons.\n    Subclasses must implement get_comparison_builder() to provide the dataset-specific logic.\n    \"\"\"\n\n    comparison_builder: ComparisonDatasetBuilder\n    swap: bool = False  # do data augmentation by swapping the order of the completions\n\n    @property\n    def comparison_renderer(self) -> ComparisonRenderer:\n        return ComparisonRendererFromChatRenderer(self.renderer)\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        train_dataset, test_dataset = self.comparison_builder.get_train_and_test_datasets()\n        comparison_renderer = self.comparison_renderer\n        rng = random.Random(0)\n\n        def comparison_to_datum(labeled_comparison: LabeledComparison) -> tinker.Datum:\n            model_input, weights = comparison_renderer.to_model_input_weights(labeled_comparison)\n            return 
datum_from_model_input_weights(\n                model_input, weights, self.common_config.max_length\n            )\n\n        def example_to_data(example: dict[str, str]) -> list[tinker.Datum]:\n            labeled_comparison = self.comparison_builder.example_to_labeled_comparison(example)\n            if labeled_comparison is None:\n                return []\n            if self.swap:\n                return [\n                    comparison_to_datum(labeled_comparison),\n                    comparison_to_datum(labeled_comparison.swap()),\n                ]\n            else:\n                if rng.random() < 0.5:\n                    labeled_comparison = labeled_comparison.swap()\n                return [comparison_to_datum(labeled_comparison)]\n\n        if test_dataset is not None:\n            test_supervised_dataset = SupervisedDatasetFromHFDataset(\n                test_dataset,\n                batch_size=len(test_dataset),\n                flatmap_fn=example_to_data,\n            )\n        else:\n            test_supervised_dataset = None\n\n        return SupervisedDatasetFromHFDataset(\n            train_dataset,\n            batch_size=self.common_config.batch_size,\n            flatmap_fn=example_to_data,\n        ), test_supervised_dataset\n\n\n@chz.chz\nclass ComparisonBuilderFromJsonl(ComparisonDatasetBuilder):\n    \"\"\"Load LabeledComparisons from JSONL files produced by combine_preference_datasets.py.\"\"\"\n\n    train_path: str\n    test_path: str | None = None\n\n    def get_train_and_test_datasets(self) -> tuple[datasets.Dataset, datasets.Dataset | None]:\n        \"\"\"Load datasets from JSONL files.\"\"\"\n        import json\n\n        import blobfile\n\n        # Load train dataset\n        train_data = []\n        with blobfile.BlobFile(self.train_path, \"r\", streaming=False) as f:\n            for line in f:\n                train_data.append(json.loads(line.strip()))\n\n        train_dataset = datasets.Dataset.from_list(train_data)\n\n        # Load test dataset if provided\n        test_dataset = None\n        if self.test_path:\n            test_data = []\n            with blobfile.BlobFile(self.test_path, \"r\", streaming=False) as f:\n                for line in f:\n                    test_data.append(json.loads(line.strip()))\n            test_dataset = datasets.Dataset.from_list(test_data)\n\n        return train_dataset, test_dataset\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        \"\"\"Convert a dictionary (from JSONL) back to a LabeledComparison.\"\"\"\n        # The JSONL contains the raw LabeledComparison as a dict\n        # with 'comparison' and 'label' keys\n        if \"comparison\" not in example or \"label\" not in example:\n            return None\n\n        comparison_dict = example[\"comparison\"]\n\n        # Reconstruct the Comparison object\n        comparison = Comparison(\n            prompt_conversation=comparison_dict[\"prompt_conversation\"],\n            completion_A=comparison_dict[\"completion_A\"],\n            completion_B=comparison_dict[\"completion_B\"],\n        )\n\n        return LabeledComparison(comparison=comparison, label=example[\"label\"])\n"
  },
  {
    "path": "tinker_cookbook/preference/train_dpo.py",
    "content": "\"\"\"\nDirect Preference Optimization (DPO) training\n\"\"\"\n\nimport asyncio\nimport logging\nfrom pathlib import Path\nfrom typing import cast\n\nimport chz\nimport tinker\nimport torch\nimport torch.nn.functional as F\n\nfrom tinker_cookbook import checkpoint_utils, model_info\nfrom tinker_cookbook.eval.evaluators import Evaluator, EvaluatorBuilder\nfrom tinker_cookbook.supervised.train import run_evals\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilder, SupervisedDataset\nfrom tinker_cookbook.tokenizer_utils import Tokenizer, get_tokenizer\nfrom tinker_cookbook.utils import ml_log, trace\nfrom tinker_cookbook.utils.format_colorized import format_colorized\nfrom tinker_cookbook.utils.lr_scheduling import LRSchedule, compute_schedule_lr_multiplier\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass Config:\n    \"\"\"Configuration for Direct Preference Optimization (DPO) training.\"\"\"\n\n    # Required parameters\n    log_path: str = chz.field(munger=lambda _, s: str(Path(s).expanduser()))\n    model_name: str\n    dataset_builder: ChatDatasetBuilder\n    load_checkpoint_path: str | None = None\n    renderer_name: str | None = None\n    # dataset_builder optionally returns an evaluator (test set)\n\n    # Training parameters\n    learning_rate: float = 1e-5\n    lr_schedule: LRSchedule = \"linear\"\n    num_epochs: int = 1\n    dpo_beta: float = 0.1\n\n    # Model parameters\n    lora_rank: int = 32\n\n    # Infrastructure parameters\n    num_replicas: int = 8\n    base_url: str | None = None\n\n    # Checkpointing and evaluation (0 = disabled for *_every fields)\n    evaluator_builders: list[EvaluatorBuilder] = chz.field(default_factory=list)\n    infrequent_evaluator_builders: list[EvaluatorBuilder] = chz.field(default_factory=list)\n    save_every: int = 20\n    eval_every: int = 10\n    infrequent_eval_every: int = 100\n    ttl_seconds: int | None = 604800  # 7 days\n\n    # Adam optimizer parameters\n    adam_beta1: float = 0.9\n    adam_beta2: float = 0.95\n    adam_eps: float = 1e-8\n\n    # Logging parameters\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    # Profiling\n    enable_trace: bool = False\n    span_chart_every: int = 0\n\n    # DPO-specific parameters\n    reference_model_name: str | None = None\n\n    # Maximum number of training steps. 
If None, train for num_epochs * n_batches.\n    max_steps: int | None = None\n\n\ndef create_dpo_clients(\n    config: Config,\n    resume_info: checkpoint_utils.CheckpointRecord | None = None,\n    user_metadata: dict[str, str] | None = None,\n) -> tuple[tinker.TrainingClient, tinker.SamplingClient]:\n    \"\"\"Create and configure the training client and reference sampling client for DPO.\n\n    Creates the main training client and a reference sampling client.\n    The reference sampling client is used to compute the reference model's log probabilities\n    for the DPO loss computation more efficiently than a separate training client.\n\n    Args:\n        config: DPO configuration object\n        resume_info: Resume information from checkpoint\n\n    Returns:\n        Tuple of (main training client, reference sampling client)\n    \"\"\"\n    # Create shared service client for both training and reference clients\n    service_client = tinker.ServiceClient(base_url=config.base_url)\n\n    if resume_info:\n        # Resuming interrupted DPO training - load weights + optimizer state\n        assert resume_info.state_path is not None\n        checkpoint_utils.check_renderer_name_for_checkpoint(\n            service_client, resume_info.state_path, config.renderer_name\n        )\n        training_client = service_client.create_training_client_from_state_with_optimizer(\n            resume_info.state_path, user_metadata=user_metadata\n        )\n        logger.info(f\"Resumed DPO training from {resume_info.state_path}\")\n    elif config.load_checkpoint_path:\n        # Starting fresh DPO from checkpoint - load weights only (fresh optimizer)\n        checkpoint_utils.check_renderer_name_for_checkpoint(\n            service_client, config.load_checkpoint_path, config.renderer_name\n        )\n        training_client = service_client.create_training_client_from_state(\n            config.load_checkpoint_path, user_metadata=user_metadata\n        )\n        logger.info(f\"Loaded weights from {config.load_checkpoint_path}\")\n    else:\n        training_client = service_client.create_lora_training_client(\n            base_model=config.model_name, rank=config.lora_rank, user_metadata=user_metadata\n        )\n    # Create a sampling client for the reference model from the training client\n    reference_client = training_client.save_weights_and_get_sampling_client(\"reference\")\n    return training_client, reference_client\n\n\ndef compute_dpo_loss(\n    chosen_logprobs: list[torch.Tensor],\n    rejected_logprobs: list[torch.Tensor],\n    chosen_ref_logprobs: list[torch.Tensor],\n    rejected_ref_logprobs: list[torch.Tensor],\n    dpo_beta: float,\n) -> tuple[torch.Tensor, dict[str, float]]:\n    \"\"\"Compute DPO loss and metrics.\n\n    Args:\n        chosen_logprobs: Log probabilities for chosen responses\n        rejected_logprobs: Log probabilities for rejected responses\n        chosen_ref_logprobs: Reference log probabilities for chosen responses\n        rejected_ref_logprobs: Reference log probabilities for rejected responses\n        dpo_beta: DPO beta parameter\n\n    Returns:\n        Tuple of (loss tensor, metrics dictionary)\n    \"\"\"\n    # Compute log ratios\n    chosen_log_ratio = torch.stack(\n        [lp - rlp for lp, rlp in zip(chosen_logprobs, chosen_ref_logprobs, strict=True)]\n    )\n    rejected_log_ratio = torch.stack(\n        [lp - rlp for lp, rlp in zip(rejected_logprobs, rejected_ref_logprobs, strict=True)]\n    )\n\n    # Compute DPO loss\n    losses = 
-F.logsigmoid(dpo_beta * (chosen_log_ratio - rejected_log_ratio))\n    loss = losses.mean()\n\n    # Compute metrics\n    accuracy = (chosen_log_ratio > rejected_log_ratio).float().mean().item()\n    chosen_rewards = dpo_beta * chosen_log_ratio\n    rejected_rewards = dpo_beta * rejected_log_ratio\n    margin = (chosen_rewards - rejected_rewards).mean().item()\n\n    metrics = {\n        \"dpo_loss\": loss.item(),\n        \"accuracy\": accuracy,\n        \"margin\": margin,\n        \"chosen_reward\": chosen_rewards.mean().item(),\n        \"rejected_reward\": rejected_rewards.mean().item(),\n    }\n\n    return loss, metrics\n\n\ndef do_update(\n    epoch_idx: int,\n    batch_idx: int,\n    n_batches: int,\n    total_steps: int,\n    config: Config,\n    training_client: tinker.TrainingClient,\n    reference_client: tinker.SamplingClient,\n    evaluators: list[Evaluator],\n    infrequent_evaluators: list[Evaluator],\n    dataset: SupervisedDataset,\n    ml_logger: ml_log.Logger,\n    log_path: str,\n    tokenizer: Tokenizer,\n):\n    \"\"\"Perform a single DPO training update step.\"\"\"\n    step = epoch_idx * n_batches + batch_idx\n    metrics: dict[str, int | float | str] = {\"epoch\": epoch_idx}\n\n    with trace.trace_iteration(step=step) as window:\n        # Save checkpoint if needed\n        if config.save_every > 0 and step % config.save_every == 0 and step > 0:\n            with trace.scope_span_sync(\"save_checkpoint\"):\n                save_result = checkpoint_utils.save_checkpoint(\n                    training_client=training_client,\n                    name=f\"{step:06d}\",\n                    log_path=log_path,\n                    kind=\"both\",\n                    loop_state={\"epoch\": epoch_idx, \"batch\": batch_idx},\n                    ttl_seconds=config.ttl_seconds,\n                )\n            if \"state_path\" in save_result:\n                metrics[\"state_path\"] = save_result[\"state_path\"]\n\n        learning_rate = config.learning_rate * compute_schedule_lr_multiplier(\n            lr_schedule=config.lr_schedule, step=step, total_steps=total_steps\n        )\n        adam_params = tinker.AdamParams(\n            learning_rate=learning_rate,\n            beta1=config.adam_beta1,\n            beta2=config.adam_beta2,\n            eps=config.adam_eps,\n        )\n\n        # Evaluation\n        if config.eval_every > 0 and step % config.eval_every == 0:\n            with trace.scope_span_sync(\"evals\"):\n                eval_metrics = asyncio.run(run_evals(evaluators, training_client, step))\n            metrics.update(eval_metrics)\n\n        if config.infrequent_eval_every > 0 and step % config.infrequent_eval_every == 0:\n            with trace.scope_span_sync(\"infrequent_evals\"):\n                eval_metrics = asyncio.run(run_evals(infrequent_evaluators, training_client, step))\n            metrics.update(eval_metrics)\n\n        # Prepare batch\n        with trace.scope_span_sync(\"get_batch\"):\n            data = dataset.get_batch(batch_idx)\n\n        # Split data into chosen and rejected pairs\n        chosen_data = [datum for i, datum in enumerate(data) if i % 2 == 0]\n        rejected_data = [datum for i, datum in enumerate(data) if i % 2 == 1]\n\n        # Print example for first batch\n        if step == 0:\n            for i in range(min(10, len(chosen_data))):\n                print_example(chosen_data[i], tokenizer, \"Chosen\")\n                print_example(rejected_data[i], tokenizer, \"Rejected\")\n\n        with 
trace.scope_span_sync(\"get_ref_logprobs\"):\n            # Get reference log probabilities\n            # Need to reconstruct full sequences for the sampling client\n            full_sequences = []\n            for datum in data:\n                # Reconstruct the full sequence by appending the last target token\n                target_tokens = datum.loss_fn_inputs[\"target_tokens\"].data\n                if target_tokens:\n                    full_sequence = datum.model_input.append_int(int(target_tokens[-1]))\n                    full_sequences.append(full_sequence)\n                else:\n                    # If no target tokens, just use the model input as is\n                    full_sequences.append(datum.model_input)\n\n            # Compute reference log probabilities in parallel\n            async def compute_all_ref_logprobs():\n                return await asyncio.gather(\n                    *[reference_client.compute_logprobs_async(seq) for seq in full_sequences]\n                )\n\n            all_ref_logprobs = asyncio.run(compute_all_ref_logprobs())\n\n            # Extract the relevant logprobs (skip the first token which is the prompt)\n            all_ref_logprob_seqs = [torch.tensor(logprobs[1:]) for logprobs in all_ref_logprobs]\n\n            # Split reference results into chosen and rejected\n            chosen_ref_logprob_seqs = [all_ref_logprob_seqs[i] for i in range(0, len(data), 2)]\n            rejected_ref_logprob_seqs = [all_ref_logprob_seqs[i] for i in range(1, len(data), 2)]\n\n        # Create DPO loss function\n        def dpo_loss_fn(\n            data: list[tinker.Datum], logprobs_list: list[torch.Tensor]\n        ) -> tuple[torch.Tensor, dict[str, float]]:\n            # Split logprobs into chosen and rejected\n            chosen_logprob_seqs = [logprobs_list[i] for i in range(0, len(data), 2)]\n            rejected_logprob_seqs = [logprobs_list[i] for i in range(1, len(data), 2)]\n\n            # Extract log probabilities\n            chosen_logprobs = []\n            chosen_ref_logprobs = []\n            rejected_logprobs = []\n            rejected_ref_logprobs = []\n\n            for i in range(len(chosen_data)):\n                # Compute weighted logprobs for chosen responses\n                chosen_logprob_seq = chosen_logprob_seqs[i]\n                chosen_ref_logprob_seq = chosen_ref_logprob_seqs[i]\n                chosen_weights = torch.tensor(chosen_data[i].loss_fn_inputs[\"weights\"].data)\n                chosen_logprob = torch.dot(chosen_logprob_seq.float(), chosen_weights.float())\n                chosen_ref_logprob = torch.dot(\n                    chosen_ref_logprob_seq.float(), chosen_weights.float()\n                )\n                chosen_logprobs.append(chosen_logprob)\n                chosen_ref_logprobs.append(chosen_ref_logprob)\n\n                # Compute weighted logprobs for rejected responses\n                rejected_logprob_seq = rejected_logprob_seqs[i]\n                rejected_ref_logprob_seq = rejected_ref_logprob_seqs[i]\n                rejected_weights = torch.tensor(rejected_data[i].loss_fn_inputs[\"weights\"].data)\n                rejected_logprob = torch.dot(rejected_logprob_seq.float(), rejected_weights.float())\n                rejected_ref_logprob = torch.dot(\n                    rejected_ref_logprob_seq.float(), rejected_weights.float()\n                )\n                rejected_logprobs.append(rejected_logprob)\n                rejected_ref_logprobs.append(rejected_ref_logprob)\n\n            # 
Compute DPO loss\n            return compute_dpo_loss(\n                chosen_logprobs=chosen_logprobs,\n                rejected_logprobs=rejected_logprobs,\n                chosen_ref_logprobs=chosen_ref_logprobs,\n                rejected_ref_logprobs=rejected_ref_logprobs,\n                dpo_beta=config.dpo_beta,\n            )\n\n        with trace.scope_span_sync(\"step\"):\n            # Do forward-backward with custom DPO loss\n            backward_result = training_client.forward_backward_custom(data, dpo_loss_fn).result()\n            dpo_metrics = backward_result.metrics\n\n            # Optimizer step\n            training_client.optim_step(adam_params).result()\n\n        # Prepare metrics\n        metrics.update(\n            num_pairs=len(chosen_data),\n            num_tokens=sum(datum.model_input.length for datum in data),\n            learning_rate=learning_rate,\n            progress=step / total_steps,\n            **dpo_metrics,\n        )\n\n    # Log timing metrics from trace_iteration window\n    metrics.update(window.get_timing_metrics())\n    window.write_spans_jsonl(Path(log_path) / \"timing_spans.jsonl\", step=step)\n    if config.span_chart_every > 0 and step % config.span_chart_every == 0:\n        trace.save_gantt_chart_html(window, step, Path(log_path) / f\"timing_gantt_{step:06d}.html\")\n    ml_logger.log_metrics(metrics=metrics, step=step)\n\n\ndef main(config: Config):\n    \"\"\"Main training function that runs the complete DPO training process.\"\"\"\n    resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)\n    if resume_info:\n        start_epoch = resume_info.epoch or 0\n        start_batch = resume_info.batch\n    else:\n        start_epoch = 0\n        start_batch = 0\n\n    # Setup\n    ml_logger = ml_log.setup_logging(\n        log_dir=config.log_path,\n        wandb_project=config.wandb_project,\n        wandb_name=config.wandb_name,\n        config=config,\n        do_configure_logging_module=True,\n    )\n    if config.enable_trace:\n        trace_events_path = str(Path(config.log_path) / \"trace_events.jsonl\")\n        logger.info(f\"Tracing is enabled. 
Trace events will be saved to {trace_events_path}\")\n        logger.info(\n            f\"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/\"\n        )\n        trace.trace_init(output_file=trace_events_path)\n\n    user_metadata: dict[str, str] = {}\n    if wandb_link := ml_logger.get_logger_url():\n        user_metadata[\"wandb_link\"] = wandb_link\n    checkpoint_utils.add_renderer_name_to_user_metadata(user_metadata, config.renderer_name)\n    model_info.warn_if_renderer_not_recommended(config.model_name, config.renderer_name)\n    training_client, reference_client = create_dpo_clients(config, resume_info, user_metadata)\n    tokenizer = get_tokenizer(config.model_name)\n\n    # Training setup\n    dataset, maybe_test_dataset = config.dataset_builder()\n    n_batches = len(dataset)\n    total_steps = n_batches * config.num_epochs\n    if config.max_steps is not None:\n        total_steps = min(total_steps, config.max_steps)\n\n    evaluators = [evaluator() for evaluator in config.evaluator_builders]\n    infrequent_evaluators = [evaluator() for evaluator in config.infrequent_evaluator_builders]\n    logger.info(\n        f\"Training for {n_batches} batches x {config.num_epochs} epochs = {n_batches * config.num_epochs} steps\"\n    )\n\n    # Training loop\n    reached_max_steps = False\n    for epoch_idx in range(start_epoch, config.num_epochs):\n        # Shuffle the dataset\n        logger.info(msg=f\"Starting epoch {epoch_idx}\")\n        dataset.set_epoch(seed=epoch_idx)\n\n        for batch_idx in range(start_batch if epoch_idx == start_epoch else 0, n_batches):\n            step = epoch_idx * n_batches + batch_idx\n            if config.max_steps is not None and step >= config.max_steps:\n                reached_max_steps = True\n                break\n            do_update(\n                epoch_idx=epoch_idx,\n                batch_idx=batch_idx,\n                n_batches=n_batches,\n                total_steps=total_steps,\n                config=config,\n                training_client=training_client,\n                reference_client=reference_client,\n                evaluators=evaluators,\n                infrequent_evaluators=infrequent_evaluators,\n                dataset=dataset,\n                ml_logger=ml_logger,\n                log_path=config.log_path,\n                tokenizer=tokenizer,\n            )\n        if reached_max_steps:\n            break\n\n    # Save final checkpoint if training actually happened\n    did_train = start_epoch < config.num_epochs and (\n        config.max_steps is None or start_epoch * n_batches + start_batch < config.max_steps\n    )\n    if did_train:\n        checkpoint_utils.save_checkpoint(\n            training_client=training_client,\n            name=\"final\",\n            log_path=config.log_path,\n            kind=\"both\",\n            loop_state={\"epoch\": config.num_epochs, \"batch\": 0},\n            ttl_seconds=None,\n        )\n    else:\n        logger.info(\"Training was already complete; nothing to do\")\n\n    # Cleanup\n    ml_logger.close()\n    logger.info(\"DPO training completed successfully\")\n\n\ndef print_example(datum: tinker.Datum, tokenizer: Tokenizer, label: str = \"\"):\n    \"\"\"Print a formatted example from the dataset.\"\"\"\n    int_tokens = list(datum.model_input.to_ints())\n    weights = datum.loss_fn_inputs[\"weights\"].data\n    logger.info(f\"\\n{label} Example:\")\n    
logger.info(format_colorized(int_tokens, cast(list[float], weights), tokenizer))\n"
  },
  {
    "path": "tinker_cookbook/preference/types.py",
    "content": "\"\"\"\nTypes for preference learning and Direct Preference Optimization (DPO).\n\nThis module defines the core data structures used for preference learning,\nincluding comparisons between model outputs and preference models.\n\"\"\"\n\nimport logging\nfrom dataclasses import dataclass\nfrom typing import Literal\n\nimport chz\nimport tinker\nimport torch\nfrom tinker import SamplingClient, types\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.tokenizer_utils import Tokenizer, get_tokenizer\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass Comparison:\n    prompt_conversation: list[renderers.Message]\n    completion_A: list[renderers.Message]\n    completion_B: list[renderers.Message]\n\n    def swap(self) -> \"Comparison\":\n        return Comparison(\n            prompt_conversation=self.prompt_conversation,\n            completion_A=self.completion_B,\n            completion_B=self.completion_A,\n        )\n\n\n@dataclass\nclass LabeledComparison:\n    comparison: Comparison\n    label: Literal[\"A\", \"B\", \"Tie\"]\n\n    def swap(self) -> \"LabeledComparison\":\n        return LabeledComparison(\n            comparison=self.comparison.swap(),\n            label={\"A\": \"B\", \"B\": \"A\", \"Tie\": \"Tie\"}[self.label],  # pyright: ignore[reportArgumentType]\n        )\n\n\nclass ComparisonRenderer:\n    def build_generation_prompt(self, comparison: Comparison) -> types.ModelInput:\n        raise NotImplementedError\n\n    def to_model_input_weights(\n        self, labeled_comparison: LabeledComparison\n    ) -> tuple[types.ModelInput, torch.Tensor]:\n        raise NotImplementedError\n\n    @property\n    def tokenizer(self) -> Tokenizer:\n        raise NotImplementedError\n\n\nclass ComparisonRendererFromChatRenderer(ComparisonRenderer):\n    # TODO probably shouldn't be in types.py\n    def __init__(self, convo_renderer: renderers.Renderer):\n        self.convo_renderer = convo_renderer\n\n    def _comparison_to_convo(self, comparison: Comparison) -> list[renderers.Message]:\n        return [\n            *comparison.prompt_conversation,\n            {\"role\": \"system\", \"content\": \"==== Completion A ====\"},\n            *comparison.completion_A,\n            {\"role\": \"system\", \"content\": \"==== Completion B ====\"},\n            *comparison.completion_B,\n            {\"role\": \"system\", \"content\": \"==== Preference ====\"},\n        ]\n\n    def build_generation_prompt(self, comparison: Comparison) -> types.ModelInput:\n        return self.convo_renderer.build_generation_prompt(self._comparison_to_convo(comparison))\n\n    def to_model_input_weights(\n        self, labeled_comparison: LabeledComparison\n    ) -> tuple[types.ModelInput, torch.Tensor]:\n        convo = self._comparison_to_convo(labeled_comparison.comparison)\n        convo_with_pref = convo + [{\"role\": \"assistant\", \"content\": labeled_comparison.label}]\n        model_input, weights = self.convo_renderer.build_supervised_example(convo_with_pref)\n        # TODO: support images in preference learning\n        assert all(isinstance(c, tinker.types.EncodedTextChunk) for c in model_input.chunks), (\n            \"Preference learning currently only supports text-only content.\"\n        )\n        # Truncate at the first weight==1 position + 1\n        tokens = model_input.to_ints()\n        first_weight_one_index = int(torch.nonzero(weights == 1.0)[0])\n        truncated_tokens = tokens[: first_weight_one_index + 1]\n        truncated_weights = weights[: 
first_weight_one_index + 1]\n        return types.ModelInput.from_ints(truncated_tokens), truncated_weights\n\n    @property\n    def tokenizer(self) -> Tokenizer:\n        return self.convo_renderer.tokenizer\n\n\nclass PreferenceModel:\n    async def __call__(self, comparison: Comparison) -> float:\n        \"\"\"\n        1: B is strongly preferred\n        0: Tie\n        -1: A is strongly preferred\n        \"\"\"\n        raise NotImplementedError\n\n\nclass PreferenceModelBuilder:\n    def __call__(self) -> PreferenceModel:\n        raise NotImplementedError\n\n\nclass PreferenceModelFromChatRenderer(PreferenceModel):\n    def __init__(self, convo_renderer: renderers.Renderer, sampling_client: SamplingClient):\n        self.comparison_renderer = ComparisonRendererFromChatRenderer(convo_renderer)\n        self.sampling_client = sampling_client\n\n    async def __call__(self, comparison: Comparison) -> float:\n        pm_input = self.comparison_renderer.build_generation_prompt(comparison)\n        response = await self.sampling_client.sample_async(\n            pm_input,\n            num_samples=1,\n            sampling_params=types.SamplingParams(temperature=0.0, max_tokens=1),\n        )\n        # TODO use probabilities\n        str_output = str(\n            self.comparison_renderer.tokenizer.decode(response.sequences[0].tokens)\n        ).strip()\n        if str_output == \"A\":\n            return -1.0\n        elif str_output == \"B\":\n            return 1.0\n        elif str_output == \"Tie\":\n            return 0.0\n        else:\n            logger.warning(f\"Invalid preference model output: '{str_output}'\")\n            return 0.0\n\n\n@chz.chz\nclass PreferenceModelBuilderFromChatRenderer(PreferenceModelBuilder):\n    renderer_name: str\n    model_name: str\n    rm_weights_path: str\n    base_url: str | None = None\n\n    def __call__(self) -> PreferenceModel:\n        convo_renderer = renderers.get_renderer(self.renderer_name, get_tokenizer(self.model_name))\n        sampling_client = tinker.ServiceClient(base_url=self.base_url).create_sampling_client(\n            model_path=self.rm_weights_path,\n        )\n        return PreferenceModelFromChatRenderer(convo_renderer, sampling_client)\n"
  },
  {
    "path": "tinker_cookbook/py.typed",
    "content": ""
  },
  {
    "path": "tinker_cookbook/recipes/README.md",
    "content": "# Cookbook Recipes\n\nTinker allows you to flexibly customize your training environment.\nWe will first introduce a few simple training scripts to help you get started, and then cover a broad range of different use cases.\n\n## Getting Started\n\nTinker Cookbook comes with useful abstractions so you can flexibly customize your experiments. Here are some minimal launch scripts:\n- [`rl_basic.py`](./rl_basic.py): a template script to configure reinforcement learning.\n- [`sl_basic.py`](./sl_basic.py): a template script to configure supervised learning.\n\nTo explain what goes under-the-hood, we also provide minimal, self-contained scripts that directly use the TinkerAPI to train LLMs.\n- [`rl_loop.py`](./rl_loop.py): a minimal reinforcement learning training loop.\n- [`sl_loop.py`](./sl_loop.py): a minimal supervised learning training loop.\n\n## More Post-Training Examples\n\nBuilding on Tinker and Tinker Cookbook, we can easily customize a wide range of training environments for LLMs.\nWe provide the following examples:\n- **[Chat supervised learning](./chat_sl/)**: supervised fine-tuning on conversational datasets like Tulu3.\n- **[Math reasoning](./math_rl/)**: improve LLM reasoning capability by rewarding it for answering math questions correctly.\n- **[Code reasoning](./code_rl/)**: train LLMs on competitive programming problems with sandboxed code execution (DeepCoder replication).\n- **[Preference learning](./preference/)**: showcase a three-stage RLHF pipeline: 1) supervised fine-tuning, 2) learning a reward model, 3) RL against the reward model.\n- **[Tool use](./search_tool/)**: train LLMs to better use retrieval tools to answer questions more accurately.\n- **[Prompt distillation](./prompt_distillation/)**: internalize long and complex instructions into LLMs.\n- **[Multi-Agent](./multiplayer_rl/)**: optimize LLMs to play against another LLM or themselves.\n- **[Model distillation](./distillation/)**: use on-policy distillation or SFT to distill intelligence from a teacher model.\n- **[Rubric-based grading](./rubric/)**: use an LLM grader with rubrics to provide rewards for RL training.\n- **[Verifiers environments](./verifiers_rl/)**: use RL environments from Prime Intellect's Environments Hub with Tinker.\n- **[VLM image classification](./vlm_classifier/)**: train vision-language models as image classifiers.\n- **[Harbor RL](./harbor_rl/)**: RL training on Harbor-formatted tasks (e.g., Terminal-Bench) with sandboxed code execution.\n\nThese examples are located in each subfolder, and their `README.md` file will walk you through the key implementation details, the commands to run them, and the expected performance.\n\n### Logging and Recovering From Training Interruptions\n\nOur examples support the following CLI arguments to log the results.\n\n1. `wandb_project`: When provided, logs will be sent to your Weights & Biases project. Without this argument, training scripts save logs locally only.\n2. `log_path`: Controls where training artifacts are saved.\n  - Default behavior: If not specified, each run generates a unique name and saves to `/tmp/tinker-examples`\n  - Output files:\n    - `{log_path}/metrics.jsonl` saves training metrics.\n    - `{log_path}/checkpoints.jsonl` records all the checkpoints saved during training. You can share these checkpoints for model release, offline evaluation, etc.\n  - Resuming: When using an existing `log_path`, you can either overwrite the previous run or resume training. 
This is particularly useful for recovering from runtime interruptions (see the example below).
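\n\nFor example, pin `log_path` (and optionally `wandb_project`) when launching a run, and re-run the same command after an interruption. The following is a sketch using the chat SFT recipe, with placeholder paths:\n\n```bash\npython -m tinker_cookbook.recipes.chat_sl.train \\\n    model_name=Qwen/Qwen3-8B-Base \\\n    dataset=no_robots \\\n    log_path=/tmp/tinker-examples/chat_sl/my-run \\\n    wandb_project=cookbook_sl\n```\n\nWith the recipe's default `behavior_if_log_dir_exists=ask`, re-running against an existing `log_path` should prompt you to choose between overwriting the previous run and resuming it.\n"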
  },
  {
    "path": "tinker_cookbook/recipes/chat_sl/README.md",
    "content": "# Supervised Learning\n\n## SFT on NoRobots\n\n```bash\npython -m tinker_cookbook.recipes.chat_sl.train \\\n    model_name=Qwen/Qwen3-8B-Base \\\n    dataset=no_robots \\\n    learning_rate=5e-4 \\\n    batch_size=64 \\\n    lora_rank=64 \\\n    eval_every=20 \\\n    save_every=20 \\\n    wandb_project=cookbook_sl\n```\n\nAfter 140 steps of training, `test/nll` decreases to 1.788.\n\n## SFT on Tulu3 dataset\n\n```bash\npython -m tinker_cookbook.recipes.chat_sl.train \\\n    model_name=Qwen/Qwen3-8B-Base \\\n    dataset=tulu3 \\\n    learning_rate=5e-4 \\\n    batch_size=128 \\\n    lora_rank=64 \\\n    eval_every=500 \\\n    save_every=500 \\\n    wandb_project=cookbook_sl\n```\n\nAfter 1740 steps of training, `test/nll` decreases to 0.50.\nPerformance can be further improved by training longer with a higher `lora_rank` and lower `batch_size`.\n\n## Adding your own dataset\n\nThe base classes in [tinker_cookbook/supervised/data.py](../../supervised/data.py) support loading new data in the following way:\n- `SupervisedDatasetFromHFDataset` loads dataset on Hugging Face hub with a postprocessing function\n- `StreamingSupervisedDatasetFromHFDataset` works similarly, but supports streaming\n- `FromConversationFileBuilder` supports data loading from a JSONL file\n"
  },
  {
    "path": "tinker_cookbook/recipes/chat_sl/chat_datasets.py",
    "content": "\"\"\"\nDatasets for supervised learning (SFT) that use chat-formatted data, which we\nconvert to tokens using a Renderer.\n\"\"\"\n\nimport logging\nfrom typing import cast\n\nimport chz\nimport datasets\nimport tinker\n\nfrom tinker_cookbook.renderers import TrainOnWhat\nfrom tinker_cookbook.supervised.data import (\n    SupervisedDatasetFromHFDataset,\n    conversation_to_datum,\n)\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilder, SupervisedDataset\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass Tulu3Builder(ChatDatasetBuilder):\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset]:\n        dataset = datasets.load_dataset(\"allenai/tulu-3-sft-mixture\")\n        dataset = cast(datasets.DatasetDict, dataset)\n        dataset = dataset[\"train\"]\n        dataset = dataset.shuffle(seed=0)\n        test_ds = dataset.take(1024)\n        train_ds = dataset.skip(1024)\n\n        # Use train_on_what from common_config if provided, otherwise default to LAST_ASSISTANT_MESSAGE\n        train_on_what = (\n            TrainOnWhat(self.common_config.train_on_what)\n            if self.common_config.train_on_what\n            else TrainOnWhat.LAST_ASSISTANT_MESSAGE\n        )\n\n        # take the last 1000 as test, the rest as train\n        def map_fn(row: dict) -> tinker.Datum:\n            return conversation_to_datum(\n                row[\"messages\"], self.renderer, self.common_config.max_length, train_on_what\n            )\n\n        return SupervisedDatasetFromHFDataset(\n            train_ds, batch_size=self.common_config.batch_size, map_fn=map_fn\n        ), SupervisedDatasetFromHFDataset(\n            test_ds, batch_size=self.common_config.batch_size, map_fn=map_fn\n        )\n\n\n@chz.chz\nclass NoRobotsBuilder(ChatDatasetBuilder):\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset]:\n        dataset = datasets.load_dataset(\"HuggingFaceH4/no_robots\")\n        dataset = cast(datasets.DatasetDict, dataset)\n        train_dataset = dataset[\"train\"]\n        test_dataset = dataset[\"test\"]\n        train_dataset = train_dataset.shuffle(seed=0)\n\n        # Use train_on_what from common_config if provided, otherwise use default\n        train_on_what = (\n            TrainOnWhat(self.common_config.train_on_what)\n            if self.common_config.train_on_what\n            else TrainOnWhat.ALL_ASSISTANT_MESSAGES\n        )\n\n        def map_fn(row: dict) -> tinker.Datum:\n            return conversation_to_datum(\n                row[\"messages\"], self.renderer, self.common_config.max_length, train_on_what\n            )\n\n        return SupervisedDatasetFromHFDataset(\n            train_dataset, batch_size=self.common_config.batch_size, map_fn=map_fn\n        ), SupervisedDatasetFromHFDataset(\n            test_dataset, batch_size=self.common_config.batch_size, map_fn=map_fn\n        )\n"
  },
  {
    "path": "tinker_cookbook/recipes/chat_sl/train.py",
    "content": "\"\"\"\nBasic CLI for training with supervised learning. Currently only used for integration tests.\n\n\"\"\"\n\nimport asyncio\nfrom datetime import datetime\n\nimport chz\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils, renderers\nfrom tinker_cookbook.eval.evaluators import EvaluatorBuilder\nfrom tinker_cookbook.recipes.chat_sl import chat_datasets\nfrom tinker_cookbook.supervised import train\nfrom tinker_cookbook.supervised.data import FromConversationFileBuilder\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilder, ChatDatasetBuilderCommonConfig\nfrom tinker_cookbook.utils.lr_scheduling import LRSchedule\n\n\n@chz.chz\nclass CLIConfig:\n    # Required parameters\n    log_path: str | None = None\n    model_name: str = \"meta-llama/Llama-3.1-8B\"\n    load_checkpoint_path: str | None = None\n    dataset: str = \"no_robots\"\n\n    # Training parameters\n    learning_rate: float = 1e-4\n    lr_schedule: LRSchedule = \"linear\"\n    num_epochs: int = 1\n\n    # Model parameters\n    lora_rank: int = 32\n\n    # Infrastructure parameters\n    base_url: str | None = None\n\n    # Checkpointing and evaluation\n    save_every: int = 20\n    eval_every: int = 20\n    infrequent_eval_every: int = 100\n    inline_evals: str | None = None\n\n    # Dataset-specific parameters\n    renderer_name: str | None = None\n    train_on_what: renderers.TrainOnWhat | None = None  # TrainOnWhat option\n    max_length: int | None = 16384\n    batch_size: int = 256\n\n    # Logging parameters\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\ndef get_dataset_builder(\n    dataset: str,\n    model_name: str,\n    renderer_name: str,\n    max_length: int | None,\n    batch_size: int,\n    train_on_what: renderers.TrainOnWhat | None = None,\n) -> ChatDatasetBuilder:\n    # Note that sft/train can work with non-chat datasets, but this CLI only supports chat datasets\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=model_name,\n        renderer_name=renderer_name,\n        max_length=max_length,\n        batch_size=batch_size,\n        train_on_what=train_on_what,\n    )\n\n    if dataset == \"tulu3\":\n        return chat_datasets.Tulu3Builder(common_config=common_config)\n    elif dataset == \"no_robots\":\n        return chat_datasets.NoRobotsBuilder(common_config=common_config)\n    elif dataset.endswith(\".jsonl\"):\n        # Load conversations from a JSONL file\n        return FromConversationFileBuilder(\n            common_config=common_config,\n            file_path=dataset,\n        )\n    else:\n        raise ValueError(f\"Unknown dataset: {dataset}\")\n\n\ndef get_infrequent_evaluator_builders(\n    inline_evals: str | None, renderer_name: str, model_name: str\n) -> list[EvaluatorBuilder]:\n    if inline_evals is None:\n        return []\n    elif inline_evals == \"inspect\":\n        from tinker_cookbook.eval.inspect_evaluators import InspectEvaluatorBuilder\n\n        builder = InspectEvaluatorBuilder(\n            tasks=[\"inspect_evals/gsm8k\", \"inspect_evals/ifeval\"],\n            renderer_name=renderer_name,\n            model_name=model_name,\n            temperature=0.6,\n            max_tokens=1000,\n            limit=None,\n            debug_errors=True,\n            log_dir=None,\n            max_connections=512,\n            log_level=\"INFO\",\n        )\n        return [builder]\n  
  else:\n        raise ValueError(f\"Unknown inline evaluator: {inline_evals}\")\n\n\ndef cli_main(cli_config: CLIConfig):\n    # build full config\n    model_name = cli_config.model_name.replace(\"/\", \"-\")\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    run_name = f\"{cli_config.dataset}-{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.batch_size}batch-{date_and_time}\"\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/chat_sl/{run_name}\"\n\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n    renderer_name = checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n    config = train.Config(\n        log_path=log_path,\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        dataset_builder=get_dataset_builder(\n            cli_config.dataset,\n            cli_config.model_name,\n            renderer_name,\n            cli_config.max_length,\n            cli_config.batch_size,\n            cli_config.train_on_what,\n        ),\n        evaluator_builders=[],\n        infrequent_evaluator_builders=get_infrequent_evaluator_builders(\n            cli_config.inline_evals,\n            renderer_name,\n            cli_config.model_name,\n        ),\n        learning_rate=cli_config.learning_rate,\n        lr_schedule=cli_config.lr_schedule,\n        num_epochs=cli_config.num_epochs,\n        base_url=cli_config.base_url,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        lora_rank=cli_config.lora_rank,\n        save_every=cli_config.save_every,\n        eval_every=cli_config.eval_every,\n        infrequent_eval_every=cli_config.infrequent_eval_every,\n        max_steps=cli_config.max_steps,\n    )\n    asyncio.run(train.main(config))\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    cli_main(cli_config)\n"
  },
  {
    "path": "tinker_cookbook/recipes/code_rl/README.md",
    "content": "# Replicating DeepCoder with Tinker\n\nCompetitive programming problems are a common testbed for RL with LLMs. The recent [DeepCoder](https://pretty-radio-b75.notion.site/DeepCoder-A-Fully-Open-Source-14B-Coder-at-O3-mini-Level-1cf81902c14680b3bee5eb349a512a51) blog post introduces a dataset and training pipeline for this purpose. This recipe demonstrates a similar setup using `Qwen3-4B-Instruct-2507`.\n\n## Running This Demo\n\n### Sandboxing\n\nSandboxing is essential for safely executing generated code during training and evaluation. Two sandbox backends are supported:\n\n#### SandboxFusion (Default)\n\n[Sandbox Fusion](https://bytedance.github.io/SandboxFusion/) provides local Docker-based sandboxing. You can start a local sandbox in Docker with:\n\n```bash\ndocker run -it -p 8080:8080 \\\n    -v ${TINKER_COOKBOOK_ROOT}/tinker_cookbook/recipes/code_rl/sandbox_config/local.yaml:/root/sandbox/sandbox/configs/local.yaml \\\n    volcengine/sandbox-fusion:server-20250609\n```\n\nHere, `${TINKER_COOKBOOK_ROOT}` is the absolute path to your local `tinker-cookbook` repository. The training script reads the sandbox endpoint from the `SANDBOX_URL` environment variable. By default it uses `http://localhost:8080/run_code`. Example:\n\n```bash\nexport SANDBOX_URL=http://localhost:8080/run_code\n```\n\nIf you prefer not to use Docker, you can set up the sandbox manually by following the instructions in the [Sandbox Fusion repository](https://github.com/bytedance/SandboxFusion?tab=readme-ov-file#installation).\n\n#### Modal (Alternative)\n\n[Modal](https://modal.com/docs/guide/sandbox) provides cloud-based sandboxed execution without local Docker setup. To use Modal:\n\n1. Install the modal extra and authenticate:\n```bash\nuv pip install 'tinker-cookbook[modal] @ git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly'\nmodal token new\n```\n\n2. Set the sandbox backend in your training command:\n```bash\npython -m tinker_cookbook.recipes.code_rl.train \\\n    sandbox_backend=modal \\\n    ...\n```\n\nOptional environment variables for Modal:\n- `MODAL_POOL_SIZE`: Number of concurrent sandboxes (default: 32)\n- `MODAL_CREATION_RATE_LIMIT`: Max sandboxes created per second (default: 4)\n\n### Example command\n\nTrain a `Qwen3-4B-Instruct-2507` model with:\n\n```bash\npython -m tinker_cookbook.recipes.code_rl.train \\\n    model_name=\"Qwen/Qwen3-4B-Instruct-2507\" \\\n    group_size=8 groups_per_batch=128 \\\n    learning_rate=4e-5 \\\n    lora_rank=32 \\\n    max_tokens=24576\n```\n\nAfter 100 steps of training, you can expect the following performance on **LiveCodeBench v6 (2025.02–2025.05)**:\n\n| Model | Pass@1 | Pass@8 |\n|-------|--------|--------|\n| Qwen3-4B-Instruct-2507 (before training) | 33.8% | 44.3% |\n| Qwen3-4B-Instruct-2507 (after 100 steps) | 42.7% | 55.0% |\n\n[1] Luo, M., Tan, S., Huang, R., Patel, A., Ariyak, A., Wu, Q., Shi, X., Xin, R., Cai, C., Weber, M., Zhang, C., Li, L. E., Popa, R. A., & Stoica, I. (2025). DeepCoder: A fully open-source 14B coder at O3-mini level.\n"
  },
  {
    "path": "tinker_cookbook/recipes/code_rl/code_env.py",
    "content": "from __future__ import annotations\n\nimport json\nimport logging\nfrom collections.abc import Sequence\nfrom typing import Any, Literal, cast\n\nimport chz\nfrom datasets import Dataset, concatenate_datasets, load_dataset\n\nfrom tinker_cookbook import model_info, tokenizer_utils\nfrom tinker_cookbook.recipes.code_rl.code_grading import taco_to_lcb_format\nfrom tinker_cookbook.recipes.code_rl.deepcoder_tool import (\n    DeepcoderReward,\n    DeepcoderTask,\n    DeepcoderTool,\n)\nfrom tinker_cookbook.recipes.code_rl.lcb_utils import fetch_live_code_bench_system_prompt\nfrom tinker_cookbook.renderers import get_renderer\nfrom tinker_cookbook.renderers.base import Message, Renderer\nfrom tinker_cookbook.rl.types import Env, EnvGroupBuilder, RLDataset, RLDatasetBuilder\nfrom tinker_cookbook.sandbox import SandboxBackend\nfrom tinker_cookbook.tool_use import build_agent_tool_env\n\nlogger = logging.getLogger(__name__)\n\n\ndef _load_deepcoder_split(split: Literal[\"train\", \"test\"]) -> Dataset:\n    logger.info(\"Loading DeepCoder dataset split: %s\", split)\n    if split == \"train\":\n        names = (\"primeintellect\", \"taco\", \"lcbv5\")\n    else:\n        names = (\"codeforces\", \"lcbv5\")\n\n    datasets = []\n    for name in names:\n        logger.info(f\"  Loading {name}...\")\n        ds = load_dataset(\"agentica-org/DeepCoder-Preview-Dataset\", name=name, split=split)\n        datasets.append(cast(Dataset, ds))\n\n    return cast(Dataset, concatenate_datasets(datasets))\n\n\ndef _ensure_dict(metadata: Any) -> dict[str, Any]:\n    if isinstance(metadata, str):\n        try:\n            metadata = json.loads(metadata)\n        except json.JSONDecodeError:\n            logger.warning(\"Failed to deserialize metadata: %s\", metadata)\n            return {}\n    if isinstance(metadata, dict):\n        return metadata\n    return {}\n\n\ndef _normalize_tests(raw_tests: Any, metadata: dict[str, Any]) -> list[dict[str, Any]]:\n    \"\"\"Normalize test cases to a unified format.\"\"\"\n    tests = raw_tests\n    if isinstance(tests, str):\n        try:\n            tests = json.loads(tests)\n        except json.JSONDecodeError:\n            logger.warning(\"Failed to deserialize tests. 
Dropping sample.\")\n            return []\n    if isinstance(tests, dict) and \"inputs\" in tests and \"outputs\" in tests:\n        tests = taco_to_lcb_format(tests)\n    if isinstance(tests, dict):\n        tests = [tests]\n\n    normalized: list[dict[str, Any]] = []\n    for test in tests or []:\n        if not isinstance(test, dict):\n            continue\n        testtype = test.get(\"testtype\") or \"stdin_stdout\"\n        test_metadata = _ensure_dict(test.get(\"metadata\", {}))\n        if testtype == \"functional\":\n            func_name = test_metadata.get(\"func_name\") or metadata.get(\"func_name\")\n            if func_name is not None:\n                test_metadata[\"func_name\"] = str(func_name)\n        normalized.append(\n            {\n                \"input\": str(test.get(\"input\", \"\")),\n                \"output\": str(test.get(\"output\", \"\")),\n                \"testtype\": testtype,\n                \"metadata\": test_metadata or {\"func_name\": None},\n            }\n        )\n    return normalized\n\n\ndef _build_question(example: dict[str, Any]) -> str | None:\n    \"\"\"Build the question text with LCB system prompt.\"\"\"\n    question = example.get(\"question\") or example.get(\"prompt\") or example.get(\"problem\")\n    if not isinstance(question, str) or not question.strip():\n        return None\n    starter_code = example.get(\"starter_code\")\n    if isinstance(starter_code, str) and starter_code.strip():\n        return fetch_live_code_bench_system_prompt(question, starter_code)\n    return fetch_live_code_bench_system_prompt(question)\n\n\ndef load_deepcoder_tasks(\n    split: Literal[\"train\", \"test\"] = \"train\",\n    seed: int = 0,\n) -> list[DeepcoderTask]:\n    \"\"\"Load tasks from the DeepCoder dataset.\n\n    Args:\n        split: Which split to load (\"train\" or \"test\")\n        seed: Random seed for shuffling (train split only)\n\n    Returns:\n        List of DeepcoderTask instances with normalized test cases\n    \"\"\"\n    ds: Dataset = _load_deepcoder_split(split)\n    if split == \"train\":\n        ds = ds.shuffle(seed=seed)\n\n    logger.info(f\"Processing {len(ds)} examples into tasks...\")\n    tasks: list[DeepcoderTask] = []\n    for item in ds:\n        row = cast(dict[str, Any], item)\n\n        # Extract and normalize metadata\n        metadata = _ensure_dict(row.get(\"metadata\", {}))\n\n        # Normalize test cases\n        raw_tests = row.get(\"tests\") or row.get(\"ground_truth\")\n        tests = _normalize_tests(raw_tests, metadata)\n        if not tests:\n            continue\n\n        # Build problem prompt\n        problem = _build_question(row)\n        if problem is None:\n            continue\n\n        # Extract starter code if present\n        starter_code = row.get(\"starter_code\")\n        if isinstance(starter_code, str) and not starter_code.strip():\n            starter_code = None\n\n        tasks.append(\n            DeepcoderTask(\n                problem=problem,\n                tests=tests,\n                starter_code=starter_code if isinstance(starter_code, str) else None,\n            )\n        )\n\n    return tasks\n\n\ndef _initial_messages(\n    task: DeepcoderTask,\n    renderer: Renderer,\n    code_tool: DeepcoderTool,\n) -> list[Message]:\n    \"\"\"Build initial messages with tool schemas and task problem.\n\n    Note: task.problem already contains the full LCB system prompt (via _build_question),\n    including starter code if present. 
The renderer adds tool-specific formatting\n    automatically via create_conversation_prefix_with_tools().\n    \"\"\"\n    tool_schemas = [code_tool.check_solution.to_spec()]\n    prefix = renderer.create_conversation_prefix_with_tools(tools=tool_schemas)\n    return prefix + [{\"role\": \"user\", \"content\": task.problem}]\n\n\n@chz.chz\nclass DeepcoderEnvGroupBuilder(EnvGroupBuilder):\n    \"\"\"EnvGroupBuilder that creates code environments with shared sandbox backend.\"\"\"\n\n    task: DeepcoderTask\n    model_name: str\n    renderer_name: str | None\n    max_turns: int\n    group_size: int\n    sandbox_backend: SandboxBackend | None\n    timeout: int = 6\n    format_coef: float = 0.1\n    max_trajectory_tokens: int = 32 * 1024\n\n    async def make_envs(self) -> Sequence[Env]:\n        # Renderer is stateless, share across all envs in group\n        tokenizer = tokenizer_utils.get_tokenizer(self.model_name)\n        renderer_name = self.renderer_name or model_info.get_recommended_renderer_name(\n            self.model_name\n        )\n        renderer = get_renderer(renderer_name, tokenizer)\n\n        envs = []\n        for _ in range(self.group_size):\n            tool = DeepcoderTool(self.task, self.sandbox_backend, self.timeout)\n            envs.append(\n                build_agent_tool_env(\n                    renderer=renderer,\n                    tools=[tool.check_solution],\n                    initial_messages=_initial_messages(self.task, renderer, tool),\n                    reward_fn=DeepcoderReward(\n                        task=self.task,\n                        sandbox_backend=self.sandbox_backend,\n                        timeout=self.timeout,\n                        format_coef=self.format_coef,\n                    ),\n                    max_trajectory_tokens=self.max_trajectory_tokens,\n                    max_turns=self.max_turns,\n                )\n            )\n        return envs\n\n    def logging_tags(self) -> list[str]:\n        return [\"deepcoder\"]\n\n\nclass DeepcoderDataset(RLDataset):\n    \"\"\"Dataset that processes code EnvGroupBuilders once per epoch.\"\"\"\n\n    def __init__(\n        self,\n        env_group_builders: list[DeepcoderEnvGroupBuilder],\n        batch_size: int,\n    ):\n        self.env_group_builders = env_group_builders\n        self.batch_size = batch_size\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        start = index * self.batch_size\n        end = start + self.batch_size\n        return self.env_group_builders[start:end]\n\n    def __len__(self) -> int:\n        return (len(self.env_group_builders) + self.batch_size - 1) // self.batch_size\n\n\n@chz.chz\nclass DeepcoderDatasetBuilder(RLDatasetBuilder):\n    \"\"\"Build an RL dataset over DeepCoder tasks.\"\"\"\n\n    model_name_for_tokenizer: str\n    batch_size: int\n    group_size: int\n    renderer_name: str | None = None\n    max_turns: int = 2\n    format_coef: float = 0.1\n    timeout: int = 6\n    sandbox_backend: SandboxBackend | None = None\n    seed: int = 0\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:\n        # Load train tasks\n        train_tasks = load_deepcoder_tasks(\"train\", seed=self.seed)\n        train_builders = [\n            DeepcoderEnvGroupBuilder(\n                task=task,\n                model_name=self.model_name_for_tokenizer,\n                renderer_name=self.renderer_name,\n                max_turns=self.max_turns,\n                group_size=self.group_size,\n                
sandbox_backend=self.sandbox_backend,\n                timeout=self.timeout,\n                format_coef=self.format_coef,\n            )\n            for task in train_tasks\n        ]\n        train_dataset = DeepcoderDataset(\n            env_group_builders=train_builders,\n            batch_size=self.batch_size,\n        )\n\n        # Load test tasks (group_size=1 for eval)\n        test_tasks = load_deepcoder_tasks(\"test\", seed=self.seed)\n        test_builders = [\n            DeepcoderEnvGroupBuilder(\n                task=task,\n                model_name=self.model_name_for_tokenizer,\n                renderer_name=self.renderer_name,\n                max_turns=self.max_turns,\n                group_size=1,  # Single sample per task for evaluation\n                sandbox_backend=self.sandbox_backend,\n                timeout=self.timeout,\n                format_coef=self.format_coef,\n            )\n            for task in test_tasks\n        ]\n        test_dataset = DeepcoderDataset(\n            env_group_builders=test_builders,\n            batch_size=self.batch_size,\n        )\n\n        return train_dataset, test_dataset\n"
  },
  {
    "path": "tinker_cookbook/recipes/code_rl/code_grading.py",
    "content": "\"\"\"\nCode grading utilities for RL training.\n\nSupports two execution backends:\n- sandboxfusion: Local Docker-based sandbox (default)\n- modal: Cloud-based Modal sandbox\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport re\nfrom typing import Any\n\nfrom tinker_cookbook.recipes.code_rl.lcb_utils import TEST_CODE, TEST_UTIL\nfrom tinker_cookbook.sandbox import SandboxBackend, SandboxFusionClient\n\n# Global sandbox backend clients (lazily initialized)\n_sandboxfusion_client: SandboxFusionClient | None = None\n_modal_pool: Any = None  # ModalSandboxPool, but avoid import at module level\n\n\ndef _get_sandboxfusion_client() -> SandboxFusionClient:\n    \"\"\"Get or create the SandboxFusion client.\"\"\"\n    global _sandboxfusion_client\n    if _sandboxfusion_client is None:\n        _sandboxfusion_client = SandboxFusionClient()\n    return _sandboxfusion_client\n\n\ndef _get_modal_pool():\n    \"\"\"Get or create the Modal sandbox pool.\"\"\"\n    global _modal_pool\n    if _modal_pool is None:\n        import modal\n\n        from tinker_cookbook.sandbox.modal_sandbox import ModalSandboxPool\n\n        image = modal.Image.debian_slim().pip_install(\"numpy\")\n        _modal_pool = ModalSandboxPool(image=image)\n    return _modal_pool\n\n\ndef extract_code_from_model(model_response: str) -> str | None:\n    \"\"\"Extract the last fenced code block from a model response.\"\"\"\n    code_blocks = re.findall(r\"```(?:\\w+)?\\n(.*?)```\", model_response, re.DOTALL)\n    if not code_blocks:\n        return None\n    return code_blocks[-1].strip()\n\n\ndef postprocess_lcb_sample(sample: list[dict[str, Any]]) -> dict[str, str]:\n    \"\"\"Convert test cases to LiveCodeBench format for the test runner.\"\"\"\n    sample_inputs = [item[\"input\"] for item in sample]\n    sample_outputs = [item[\"output\"] for item in sample]\n\n    sample_dict: dict[str, Any] = {\n        \"inputs\": sample_inputs,\n        \"outputs\": sample_outputs,\n    }\n\n    if sample[0].get(\"testtype\") == \"functional\":\n        metadata = sample[0].get(\"metadata\", {})\n        fn_name = metadata.get(\"func_name\")\n        if fn_name is None:\n            raise AssertionError(f\"Function name missing in metadata: {metadata}. 
Sample: {sample}\")\n        sample_dict[\"fn_name\"] = fn_name\n\n    return {\n        \"input_output\": json.dumps(sample_dict),\n    }\n\n\nasync def _check_with_sandboxfusion(\n    test_cases: dict[str, str],\n    generation: str,\n    timeout: int,\n    total_timeout: int,\n) -> tuple[bool, dict[str, Any]]:\n    \"\"\"Execute tests using SandboxFusion backend.\"\"\"\n    client = _get_sandboxfusion_client()\n\n    return await client.run(\n        code=TEST_CODE % {\"timeout\": timeout},\n        files={\n            \"test_cases.txt\": json.dumps(test_cases),\n            \"code.py\": generation,\n            \"testing_util.py\": TEST_UTIL,\n        },\n        timeout=total_timeout,\n    )\n\n\nasync def _check_with_modal(\n    test_cases: dict[str, str],\n    generation: str,\n    timeout: int,\n    total_timeout: int,\n) -> tuple[bool, dict[str, Any]]:\n    \"\"\"Execute tests using Modal sandbox.\"\"\"\n    pool = _get_modal_pool()\n    result = await pool.run_in_workdir(\n        files={\n            \"test_cases.txt\": json.dumps(test_cases),\n            \"code.py\": generation,\n            \"testing_util.py\": TEST_UTIL,\n            \"run.py\": TEST_CODE % {\"timeout\": timeout},\n        },\n        command=[\"python\", \"run.py\"],\n        timeout=total_timeout,\n    )\n    return result.exit_code == 0, {\n        \"exit_code\": result.exit_code,\n        \"stdout\": result.stdout,\n        \"stderr\": result.stderr,\n    }\n\n\nasync def sandbox_check_correctness(\n    sample: list[dict[str, Any]],\n    generation: str,\n    timeout: int = 6,\n    backend: SandboxBackend | None = None,\n) -> tuple[bool, dict[str, Any]]:\n    \"\"\"\n    Check correctness of generated code using sandbox execution.\n\n    Args:\n        sample: List of test cases in LiveCodeBench format\n        generation: Generated code to test\n        timeout: Per-test timeout in seconds\n        backend: Sandbox backend to use (defaults to \"sandboxfusion\")\n\n    Returns:\n        Tuple of (all_passed: bool, details: dict)\n    \"\"\"\n    assert len(sample) >= 1, \"Sample must contain at least one test case\"\n\n    # Process test cases\n    test_cases = postprocess_lcb_sample(sample)\n    use_backend = backend or SandboxBackend.SANDBOXFUSION\n\n    try:\n        test_cnt = len(json.loads(test_cases[\"input_output\"])[\"inputs\"])\n        total_timeout = (timeout + 1) * test_cnt + 5\n\n        if use_backend == SandboxBackend.MODAL:\n            return await _check_with_modal(test_cases, generation, timeout, total_timeout)\n        elif use_backend == SandboxBackend.SANDBOXFUSION:\n            return await _check_with_sandboxfusion(test_cases, generation, timeout, total_timeout)\n        else:\n            raise ValueError(f\"Invalid sandbox backend: {use_backend}\")\n\n    except Exception as e:\n        return False, {\"error\": str(e)}\n\n\ndef taco_to_lcb_format(tests: dict[str, Any]) -> list[dict[str, Any]]:\n    \"\"\"Convert TACO-style tests to LiveCodeBench format.\"\"\"\n    inputs = tests.get(\"inputs\", [])\n    outputs = tests.get(\"outputs\", [])\n\n    n = max(len(inputs), len(outputs))\n\n    test_cases: list[dict[str, Any]] = []\n    for i in range(n):\n        inp = inputs[i] if i < len(inputs) else (inputs[0] if inputs else \"\")\n        out = outputs[i] if i < len(outputs) else (outputs[0] if outputs else \"\")\n        if isinstance(out, list):\n            out = out[0] if out else \"\"\n        case: dict[str, Any] = {\n            \"input\": inp,\n            \"output\": 
out,\n            \"metadata\": {},\n        }\n        if \"fn_name\" in tests:\n            case[\"testtype\"] = \"functional\"\n            case[\"metadata\"][\"func_name\"] = tests[\"fn_name\"]\n        else:\n            case[\"testtype\"] = \"stdin_stdout\"\n        test_cases.append(case)\n\n    return test_cases\n"
  },
  {
    "path": "tinker_cookbook/recipes/code_rl/deepcoder_tool.py",
    "content": "from __future__ import annotations\n\nimport json\nfrom dataclasses import dataclass\nfrom typing import Annotated, Any\n\nfrom tinker_cookbook.recipes.code_rl.code_grading import (\n    extract_code_from_model,\n    sandbox_check_correctness,\n)\nfrom tinker_cookbook.renderers import get_text_content\nfrom tinker_cookbook.renderers.base import Message\nfrom tinker_cookbook.sandbox import SandboxBackend\nfrom tinker_cookbook.tool_use import ToolResult, simple_tool_result, tool\nfrom tinker_cookbook.utils import logtree\n\n\n@dataclass(frozen=True)\nclass DeepcoderTask:\n    \"\"\"A single code task with problem statement and test cases.\"\"\"\n\n    problem: str\n    tests: list[dict[str, Any]]\n    starter_code: str | None = None\n\n\nclass DeepcoderTool:\n    \"\"\"Tool for testing code against a task's test cases.\n\n    Each DeepcoderTool instance is bound to a specific task (its tests).\n    \"\"\"\n\n    def __init__(\n        self,\n        task: DeepcoderTask,\n        sandbox_backend: SandboxBackend | None = None,\n        timeout: int = 6,\n    ):\n        self._task = task\n        self._sandbox_backend = sandbox_backend\n        self._timeout = timeout\n\n    @tool\n    async def check_solution(\n        self,\n        code: Annotated[str, \"Python code implementing the solution.\"],\n    ) -> ToolResult:\n        \"\"\"Execute the proposed solution against the task's test cases.\n\n        Use this to test your code before providing your final answer.\n        \"\"\"\n        try:\n            passed, details = await sandbox_check_correctness(\n                self._task.tests,\n                code,\n                timeout=self._timeout,\n                backend=self._sandbox_backend,\n            )\n            content = json.dumps(\n                {\"passed\": passed, \"details\": details},\n                ensure_ascii=False,\n            )\n            return simple_tool_result(content)\n        except Exception as e:\n            return simple_tool_result(json.dumps({\"error\": str(e), \"passed\": False}))\n\n\n@dataclass\nclass DeepcoderReward:\n    \"\"\"Reward function for code tasks.\n\n    Grades the final answer by extracting code from the last assistant message\n    and running it against the task's tests.\n\n    Formula: format_coef * (has_code_block - 1) + correct\n\n    Called once at episode end with the full message history.\n    \"\"\"\n\n    task: DeepcoderTask\n    sandbox_backend: SandboxBackend | None = None\n    timeout: int = 6\n    format_coef: float = 0.1\n\n    async def __call__(self, history: list[Message]) -> tuple[float, dict[str, float]]:\n        \"\"\"Grade the completed episode by extracting code from final assistant message.\"\"\"\n        # Find the last assistant message\n        final_message = None\n        for msg in reversed(history):\n            if msg.get(\"role\") == \"assistant\":\n                final_message = msg\n                break\n\n        if final_message is None:\n            logtree.log_text(\"No assistant message found in history.\")\n            return 0.0, {\"format\": 0.0, \"correct\": 0.0}\n\n        # Use get_text_content to properly handle thinking models (o1, o3)\n        content = get_text_content(final_message)\n\n        # Extract code from content\n        code = extract_code_from_model(content)\n        has_code_block = code is not None\n\n        # Grade the code by running tests\n        if code is not None:\n            try:\n                passed, _details = await 
sandbox_check_correctness(\n                    self.task.tests,\n                    code,\n                    timeout=self.timeout,\n                    backend=self.sandbox_backend,\n                )\n                correct = float(passed)\n            except Exception as e:\n                logtree.log_text(f\"Error running tests: {e}\")\n                correct = 0.0\n        else:\n            logtree.log_text(\"No code block detected in response.\")\n            correct = 0.0\n\n        # Reward formula\n        format_score = float(has_code_block)\n        reward = self.format_coef * (format_score - 1.0) + correct\n\n        # Log results\n        logtree.log_text(f\"Problem: {self.task.problem}\")\n        logtree.log_text(f\"Response: {content}\")\n        logtree.log_text(\n            f\"Format Valid: {'✓' if has_code_block else '✗'}, \"\n            f\"Correct: {'✓' if correct > 0 else '✗'}, \"\n            f\"Reward: {reward:.2f}\"\n        )\n\n        return reward, {\"format\": format_score, \"correct\": correct}\n"
  },
  {
    "path": "tinker_cookbook/recipes/code_rl/lcb_utils.py",
    "content": "\"\"\"\nLiveCodeBench testing utilities for sandbox execution.\n\nAdapted from https://github.com/LiveCodeBench/LiveCodeBench\n\nThis module provides TEST_UTIL and TEST_CODE strings that are used to execute\nand validate code submissions in a sandboxed environment. These utilities handle\nboth call-based and standard input/output test cases.\n\"\"\"\n\nTEST_UTIL = r'''\nimport ast\nimport json\nimport sys\nimport faulthandler\nimport platform\n\n# used for debugging to time steps\nfrom datetime import datetime\n\n# to run the solution files we're using a timing based approach\nimport signal\n\nimport numpy as np\n\n# for capturing the stdout\nfrom io import StringIO\n\n# used for testing the code that reads from input\nfrom unittest.mock import patch, mock_open\n\n#from pyext import RuntimeModule\n\nfrom enum import Enum\n\nimport types\n\n\nclass RuntimeModule:\n    @staticmethod\n    def from_string(name, _, code):\n        module = types.ModuleType(name)\n        exec(code, module.__dict__)\n        return module\n\n\ndef truncatefn(s, length=300):\n    assert isinstance(s, str)\n    if len(s) <= length:\n        return s\n\n    return s[: length // 2] + \"...(truncated) ...\" + s[-length // 2 :]\n\n\nclass CODE_TYPE(Enum):\n    call_based = 0\n    standard_input = 1\n\n\n# stuff for setting up signal timer\nclass TimeoutException(Exception):\n    pass\n\n\ndef timeout_handler(signum, frame):\n    print(\"alarm went off\")\n    # return\n    raise TimeoutException\n\n\nsignal.signal(signal.SIGALRM, timeout_handler)\n# timeout = 6  # seconds\n\n\n# used to capture stdout as a list\n# from https://stackoverflow.com/a/16571630/6416660\n# alternative use redirect_stdout() from contextlib\nclass Capturing(list):\n    def __enter__(self):\n        self._stdout = sys.stdout\n        sys.stdout = self._stringio = StringIO()\n        # Make closing the StringIO a no-op\n        self._stringio.close = lambda x: 1\n        return self\n\n    def __exit__(self, *args):\n        self.append(self._stringio.getvalue())\n        del self._stringio  # free up some memory\n        sys.stdout = self._stdout\n\n\ndef only_int_check(val):\n    return isinstance(val, int)\n\n\ndef string_int_check(val):\n    return isinstance(val, str) and val.isdigit()\n\n\ndef combined_int_check(val):\n    return only_int_check(val) or string_int_check(val)\n\n\ndef run_test(sample, test=None, debug=False, timeout=6):\n    \"\"\"\n    if test(generated_code) is not None it'll try to run the code.\n    otherwise it'll just return an input and output pair.\n    \"\"\"\n    # Disable functionalities that can make destructive changes to the test.\n    reliability_guard()\n\n    if debug:\n        print(f\"start = {datetime.now().time()}\")\n\n    try:\n        in_outs = json.loads(sample[\"input_output\"])\n    except ValueError:\n        in_outs = None\n    if in_outs:\n        if in_outs.get(\"fn_name\") is None:\n            which_type = CODE_TYPE.standard_input  # Standard input\n            method_name = None\n        else:\n            which_type = CODE_TYPE.call_based  # Call-based\n            method_name = in_outs[\"fn_name\"]\n\n    if debug:\n        print(f\"loaded input_output = {datetime.now().time()}\")\n\n    if test is None:\n        assert False, \"should not happen: test code is none\"\n        return in_outs, {\"error\": \"no test code provided\"}\n    elif test is not None:\n        results = []\n        sol = \"from string import *\\nfrom re import *\\nfrom datetime import *\\nfrom 
collections import *\\nfrom heapq import *\\nfrom bisect import *\\nfrom copy import *\\nfrom math import *\\nfrom random import *\\nfrom statistics import *\\nfrom itertools import *\\nfrom functools import *\\nfrom operator import *\\nfrom io import *\\nfrom sys import *\\nfrom json import *\\nfrom builtins import *\\nfrom typing import *\\nimport string\\nimport re\\nimport datetime\\nimport collections\\nimport heapq\\nimport bisect\\nimport copy\\nimport math\\nimport random\\nimport statistics\\nimport itertools\\nimport functools\\nimport operator\\nimport io\\nimport sys\\nimport json\\nsys.setrecursionlimit(6*10**5)\\n\"\n        if debug:\n            print(f\"loading test code = {datetime.now().time()}\")\n\n        if which_type == CODE_TYPE.call_based:\n\n            sol += test\n            if debug:\n                print(f\"sol = {sol}\")\n            signal.alarm(timeout)\n            try:\n                tmp_sol = RuntimeModule.from_string(\"tmp_sol\", \"\", sol)\n                if \"class Solution\" not in test:\n                    tmp = tmp_sol\n                else:\n                    tmp = tmp_sol.Solution()\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                if debug:\n                    print(f\"type 0 compilation error = {e}\")\n                results.append(-2)\n                return results, {\n                    \"error\": repr(e),\n                    \"error_code\": -1,\n                    \"error_message\": \"Compilation Error\",\n                }\n            signal.alarm(0)\n\n        elif which_type == CODE_TYPE.standard_input:\n            # sol\n            # if code has if __name__ == \"__main__\": then remove it\n            try:\n                astree = ast.parse(test)\n                last_block = astree.body[-1]\n                if isinstance(last_block, ast.If):\n                    condition = last_block.test\n                    if ast.unparse(condition).strip() == \"__name__ == '__main__'\":\n                        test = (\n                            ast.unparse(astree.body[:-1])\n                            + \"\\n\"\n                            + ast.unparse(last_block.body)\n                        )\n            except Exception:\n                pass\n\n            tmp_test = test.split(\"\\n\")\n\n            new_test = []\n            for x in tmp_test:\n                if (not x.startswith(\"from \")) and (not x.startswith(\"import \")):\n                    new_test.append(\"\\t\" + x + \"\\n\")\n                else:\n                    new_test.append(x + \"\\n\")\n            tmp_test = new_test\n\n            new_test = \"\"\n            started = False\n            for i in tmp_test:\n                if i.startswith(\"\\t\") and not started:\n                    new_test += \"stdin = sys.stdin\\nstdout = sys.stdout\\n\"\n                    new_test += \"def code():\\n\"\n                    new_test += i\n                    started = True\n                elif started and ((i.startswith(\"from \")) or (i.startswith(\"import \"))):\n                    new_test += \"\\t\" + i\n                else:\n                    new_test += i\n            tmp_test = new_test\n\n            sol += tmp_test\n            if debug:\n                print(f\"sol = {sol}\")\n            method_name = \"code\"\n            signal.alarm(timeout)\n            try:\n                tmp_sol = RuntimeModule.from_string(\"tmp_sol\", \"\", sol)\n                tmp = 
tmp_sol\n                signal.alarm(0)\n            except Exception as e:\n                signal.alarm(0)\n                if debug:\n                    print(f\"type 1 compilation error = {e}\")\n                results.append(-2)\n                return results, {\n                    \"error\": repr(e),\n                    \"error_code\": -1,\n                    \"error_message\": \"Compilation Error\",\n                }\n            signal.alarm(0)\n        if debug:\n            print(f\"get method = {datetime.now().time()}\")\n\n        try:\n            method = getattr(tmp, method_name)  # get_attr second arg must be str\n        except Exception:\n            signal.alarm(0)\n            e = sys.exc_info()\n            print(f\"unable to get function error = {e}\")\n            results.append(-2)\n            return results, {\n                \"error\": repr(e),\n                \"error_code\": -1,\n                \"error_message\": \"Unable to extract code\",\n            }\n\n        for index, inputs in enumerate(in_outs[\"inputs\"]):\n            raw_inputs = inputs\n            raw_outputs = in_outs[\"outputs\"][index]\n            if which_type == CODE_TYPE.call_based:\n                inputs = [json.loads(line) for line in inputs.split(\"\\n\")]\n                in_outs[\"outputs\"][index] = json.loads(in_outs[\"outputs\"][index])\n\n                truncate_line_size = 300 // (raw_inputs.count(\"\\n\") + 1)\n                raw_inputs = \"\\n\".join(\n                    [\n                        truncatefn(line, truncate_line_size)\n                        for line in raw_inputs.strip().split(\"\\n\")\n                    ]\n                )\n                raw_outputs = truncatefn(raw_outputs, 200)\n            else:\n                raw_inputs = truncatefn(raw_inputs)\n                raw_outputs = truncatefn(raw_outputs, 200)\n            # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list)\n            try:\n                if isinstance(inputs[0], dict):\n                    inputs = [{int(k): v for k, v in inputs[0].items()}]\n            except Exception:\n                True\n            try:\n                if isinstance(in_outs[\"outputs\"][index], dict):\n                    in_outs[\"outputs\"][index] = [\n                        {int(k): v for k, v in in_outs[\"outputs\"][index].items()}\n                    ]\n            except Exception:\n                True\n            try:\n                if isinstance(in_outs[\"outputs\"][index][0], dict):\n                    in_outs[\"outputs\"][index] = [\n                        {int(k): v for k, v in in_outs[\"outputs\"][index][0].items()}\n                    ]\n            except Exception:\n                True\n\n            if debug:\n                print(\n                    f\"time: {datetime.now().time()} testing index = {index}  inputs = {inputs}, {type(inputs)}. 
type = {which_type}\"\n                )\n            if which_type == CODE_TYPE.call_based:  # Call-based\n                signal.alarm(timeout)\n                faulthandler.enable()\n                try:\n                    output = method(*inputs)\n                    raw_true_output = output\n\n                    raw_true_output_copy = json.dumps(output)\n                    raw_true_output_copy = truncatefn(raw_true_output_copy, 200)\n\n                    # ground truth sequences are not tuples\n                    if isinstance(output, tuple):\n                        output = list(output)\n\n                    tmp_result = output == in_outs[\"outputs\"][index]\n                    if (\n                        isinstance(in_outs[\"outputs\"][index], list)\n                        and in_outs[\"outputs\"][index]\n                    ):\n                        tmp_result = tmp_result or (\n                            output == in_outs[\"outputs\"][index][0]\n                        )\n\n                    # ground truth sequences are not tuples\n                    try:\n                        if isinstance(output[0], tuple):\n                            tmp_result = tmp_result or (\n                                [list(x) for x in output]\n                                == in_outs[\"outputs\"][index][0]\n                            )\n                    except Exception:\n                        True\n                    results.append(tmp_result)\n                    if tmp_result != True:\n                        return results, {\n                            \"output\": raw_true_output_copy,\n                            \"expected\": raw_outputs,\n                            \"inputs\": raw_inputs,\n                            \"error_code\": -2,\n                            \"error_message\": \"Wrong Answer\",\n                        }\n                    # reset the alarm\n                    signal.alarm(0)\n                except Exception as e:\n                    signal.alarm(0)\n                    faulthandler.disable()\n                    if debug:\n                        print(\n                            f\"Standard input runtime error or time limit exceeded error = {e}\"\n                        )\n                    results.append(-1)\n                    if \"timeoutexception\" in repr(e).lower():\n                        return results, {\n                            \"error\": repr(e),\n                            \"error_code\": -3,\n                            \"error_message\": \"Time Limit Exceeded\",\n                            \"inputs\": raw_inputs,\n                            \"expected\": raw_outputs,\n                        }\n                    else:\n                        return results, {\n                            \"error\": repr(e),\n                            \"error_code\": -4,\n                            \"error_message\": \"Runtime Error\",\n                            \"inputs\": raw_inputs,\n                            \"expected\": raw_outputs,\n                        }\n                faulthandler.disable()\n                signal.alarm(0)\n                if debug:\n                    print(\n                        f\"outputs = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                    )\n            elif which_type == CODE_TYPE.standard_input:  # Standard input\n                faulthandler.enable()\n                
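# Standard-input case: run the candidate program with patched stdin (see call_method below) and capture its stdout for comparison.\n                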
passed = False\n\n                if isinstance(inputs, list):\n                    inputs = \"\\n\".join(inputs)\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    in_outs[\"outputs\"][index] = \"\\n\".join(in_outs[\"outputs\"][index])\n\n                signal.alarm(timeout)\n                with Capturing() as output:\n                    try:\n                        call_method(method, inputs)\n                        # reset the alarm\n                        signal.alarm(0)\n                        passed = True\n                    except Exception as e:\n                        # runtime error or took too long\n                        signal.alarm(0)\n                        print(\n                            f\"Call-based runtime error or time limit exceeded error = {repr(e)}{e}\"\n                        )\n                        results.append(-1)\n                        if \"timeoutexception\" in repr(e).lower():\n                            return results, {\n                                \"error\": repr(e),\n                                \"error_code\": -3,\n                                \"error_message\": \"Time Limit Exceeded\",\n                                \"inputs\": raw_inputs,\n                                \"expected\": raw_outputs,\n                            }\n                        else:\n                            return results, {\n                                \"error\": repr(e),\n                                \"error_code\": -4,\n                                \"error_message\": \"Runtime Error\",\n                                \"inputs\": raw_inputs,\n                                \"expected\": raw_outputs,\n                            }\n                    signal.alarm(0)\n                raw_true_output = output[0]\n                raw_true_output_copy = truncatefn(raw_true_output, 200)\n                output = raw_true_output.splitlines()\n                if not passed:\n                    if debug:\n                        nl = \"\\n\"\n                        if not isinstance(inputs, list):\n                            print(\n                                f\"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                            )\n                        else:\n                            print(\n                                f\"not passed output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                            )\n                    continue\n\n                if passed and debug:\n                    print(\n                        f\"==> output = {output}, test outputs = {in_outs['outputs'][index]}\"\n                    )\n\n                if custom_compare_(output, in_outs[\"outputs\"][index]):\n                    tmp_result = True\n                    results.append(tmp_result)\n                    continue\n\n                # ground truth sequences are expressed as lists not tuples\n                if isinstance(output, tuple):\n                    output = list(output)\n\n                tmp_result = False\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == 
in_outs[\"outputs\"][index])\n                        if isinstance(output[0], str):\n                            tmp_result = tmp_result or (\n                                [e.strip() for e in output] == in_outs[\"outputs\"][index]\n                            )\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check1 exception = {e}\")\n                    pass\n\n                if tmp_result == True:\n                    results.append(tmp_result)\n                    continue\n\n                # try one more time without \\n\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    for tmp_index, i in enumerate(in_outs[\"outputs\"][index]):\n                        in_outs[\"outputs\"][index][tmp_index] = i.split(\"\\n\")\n                        in_outs[\"outputs\"][index][tmp_index] = [\n                            x.strip() for x in in_outs[\"outputs\"][index][tmp_index] if x\n                        ]\n                else:\n                    in_outs[\"outputs\"][index] = in_outs[\"outputs\"][index].split(\"\\n\")\n                    in_outs[\"outputs\"][index] = list(\n                        filter(len, in_outs[\"outputs\"][index])\n                    )\n                    in_outs[\"outputs\"][index] = list(\n                        map(lambda x: x.strip(), in_outs[\"outputs\"][index])\n                    )\n\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check2 exception = {e}\")\n                    pass\n\n                if tmp_result == True:\n                    results.append(tmp_result)\n                    continue\n\n                # try by converting the output into a split up list too\n                if isinstance(output, list):\n                    output = list(filter(len, output))\n\n                if debug:\n                    nl = \"\\n\"\n                    if not isinstance(inputs, list):\n                        print(\n                            f\"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}\"\n                        )\n                    else:\n                        print(\n                            f\"@1 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]} {tmp_result=}\"\n                        )\n\n                if tmp_result == True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @a\")\n\n                try:\n                    tmp_result = output == [in_outs[\"outputs\"][index]]\n                    if isinstance(in_outs[\"outputs\"][index], list):\n                        tmp_result = tmp_result or (output == in_outs[\"outputs\"][index])\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check3 exception = {e}\")\n                    pass\n\n                if debug:\n                    print(f\"{tmp_result=} @b\")\n\n                
try:\n                    all_ints = all(\n                        combined_int_check(e1) and combined_int_check(e2)\n                        for e1, e2 in zip(output, in_outs[\"outputs\"][index])\n                    )\n                    if not all_ints:\n                        if debug:\n                            print(\n                                [\n                                    combined_int_check(e1) and combined_int_check(e2)\n                                    for e1, e2 in zip(output, in_outs[\"outputs\"][index])\n                                ]\n                            )\n                        output_float = [float(e) for e in output]\n                        gt_float = [float(e) for e in in_outs[\"outputs\"][index]]\n                        tmp_result = tmp_result or (\n                            (len(output_float) == len(gt_float))\n                            and np.allclose(output_float, gt_float)\n                        )\n                except Exception as e:\n                    pass\n\n                if debug:\n                    print(f\"{tmp_result=} @c\")\n\n                try:\n                    if isinstance(output[0], list):\n                        all_ints = all(\n                            combined_int_check(e1) and combined_int_check(e2)\n                            for e1, e2 in zip(output[0], in_outs[\"outputs\"][index])\n                        )\n                        if not all_ints:\n                            output_float = [float(e) for e in output[0]]\n                            gt_float = [float(e) for e in in_outs[\"outputs\"][index][0]]\n                            tmp_result = tmp_result or (\n                                (len(output_float) == len(gt_float))\n                                and np.allclose(output_float, gt_float)\n                            )\n                except Exception as e:\n                    pass\n\n                if tmp_result == True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @d\")\n                # try by converting the stuff into split up list\n                if isinstance(in_outs[\"outputs\"][index], list):\n                    for tmp_index, i in enumerate(in_outs[\"outputs\"][index]):\n                        in_outs[\"outputs\"][index][tmp_index] = set(i.split())\n                else:\n                    in_outs[\"outputs\"][index] = set(in_outs[\"outputs\"][index].split())\n\n                if debug:\n                    print(f\"{tmp_result=} @e\")\n\n                try:\n                    tmp_result = output == in_outs[\"outputs\"][index]\n                except Exception as e:\n                    if debug:\n                        print(f\"Failed check4 exception = {e}\")\n                    continue\n\n                if tmp_result == True:\n                    results.append(tmp_result)\n                    continue\n\n                if debug:\n                    print(f\"{tmp_result=} @f\")\n\n                # try by converting the output into a split up list too\n                if isinstance(output, list):\n                    for tmp_index, i in enumerate(output):\n                        output[tmp_index] = i.split()\n                    output = list(filter(len, output))\n                    for tmp_index, i in enumerate(output):\n                        output[tmp_index] = set(i)\n                else:\n                    output = 
output.split()\n                    output = list(filter(len, output))\n                    output = set(output)\n\n                if debug:\n                    print(f\"{tmp_result=} @g\")\n                # try:\n                #     tmp_result = set(frozenset(s) for s in output) == set(\n                #         frozenset(s) for s in in_outs[\"outputs\"][index]\n                #     )\n                # except Exception as e:\n                #     if debug:\n                #         print(f\"Failed check5 exception = {e}\")\n\n                # if they are all numbers, round so that similar numbers are treated as identical\n                # try:\n                #     all_ints = all(\n                #         combined_int_check(e1) and combined_int_check(e2)\n                #         for e1, e2 in zip(output, in_outs[\"outputs\"][index])\n                #     )\n                #     tmp_result = tmp_result or (\n                #         set(frozenset(round(float(t), 3) for t in s) for s in output)\n                #         == set(\n                #             frozenset(round(float(t), 3) for t in s)\n                #             for s in in_outs[\"outputs\"][index]\n                #         )\n                #     )\n                # except Exception as e:\n                #     if debug:\n                #         print(f\"Failed check6 exception = {e}\")\n\n                if debug:\n                    print(f\"{tmp_result=} @h\")\n\n                if tmp_result == True and debug:\n                    print(\"PASSED\")\n\n                results.append(tmp_result)\n                if tmp_result != True:\n                    return results, {\n                        \"output\": raw_true_output_copy,\n                        \"expected\": raw_outputs,\n                        \"inputs\": raw_inputs,\n                        \"error_code\": -2,\n                        \"error_message\": \"Wrong Answer\",\n                    }\n\n                if debug:\n                    nl = \"\\n\"\n                    if not isinstance(inputs, list):\n                        print(\n                            f\"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs.replace(nl,' new-line ')}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                        )\n                    else:\n                        print(\n                            f\"@2 output = {output}, test outputs = {in_outs['outputs'][index]}, inputs = {inputs}, {type(inputs)}, {output == [in_outs['outputs'][index]]}\"\n                        )\n\n                    print(f\"results = {results}\")\n\n    return results, {}\n\n\ndef custom_compare_(output, ground_truth):\n\n    if isinstance(output, list):\n        output_1 = \"\\n\".join(output)\n        if stripped_string_compare(output_1, ground_truth):\n            return True\n\n    if isinstance(output, list):\n        output_2 = [o.lstrip().rstrip() for o in output]\n        output_2 = \"\\n\".join(output_2)\n        if stripped_string_compare(output_2, ground_truth):\n            return True\n\n    return False\n\n\ndef stripped_string_compare(s1, s2):\n    s1 = s1.lstrip().rstrip()\n    s2 = s2.lstrip().rstrip()\n    return s1 == s2\n\n\ndef call_method(method, inputs):\n\n    if isinstance(inputs, list):\n        inputs = \"\\n\".join(inputs)\n\n    inputs_line_iterator = iter(inputs.split(\"\\n\"))\n\n    # sys.setrecursionlimit(10000)\n\n    # @patch('builtins.input', 
side_effect=inputs.split(\"\\n\"))\n    @patch(\"builtins.open\", mock_open(read_data=inputs))\n    @patch(\"sys.stdin\", StringIO(inputs))\n    @patch(\"sys.stdin.readline\", lambda *args: next(inputs_line_iterator))\n    @patch(\"sys.stdin.readlines\", lambda *args: inputs.split(\"\\n\"))\n    @patch(\"sys.stdin.read\", lambda *args: inputs)\n    # @patch('sys.stdout.write', print)\n    def _inner_call_method(_method):\n        try:\n            return _method()\n        except SystemExit as e:\n            pass\n        finally:\n            pass\n\n    return _inner_call_method(method)\n\n\ndef reliability_guard(maximum_memory_bytes=None):\n    \"\"\"\n    This disables various destructive functions and prevents the generated code\n    from interfering with the test (e.g. fork bomb, killing other processes,\n    removing filesystem files, etc.)\n    WARNING\n    This function is NOT a security sandbox. Untrusted code, including, model-\n    generated code, should not be blindly executed outside of one. See the\n    Codex paper for more information about OpenAI's code sandbox, and proceed\n    with caution.\n    \"\"\"\n\n    if maximum_memory_bytes is not None:\n        import resource\n\n        resource.setrlimit(\n            resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)\n        )\n        resource.setrlimit(\n            resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)\n        )\n        if not platform.uname().system == \"Darwin\":\n            resource.setrlimit(\n                resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)\n            )\n\n    faulthandler.disable()\n\n    import builtins\n\n    builtins.exit = None\n    builtins.quit = None\n\n    import os\n\n    os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n\n    os.kill = None\n    os.system = None\n    os.putenv = None\n    os.remove = None\n    os.removedirs = None\n    os.rmdir = None\n    os.fchdir = None\n    os.setuid = None\n    os.fork = None\n    os.forkpty = None\n    os.killpg = None\n    os.rename = None\n    os.renames = None\n    os.truncate = None\n    os.replace = None\n    os.unlink = None\n    os.fchmod = None\n    os.fchown = None\n    os.chmod = None\n    os.chown = None\n    os.chroot = None\n    os.fchdir = None\n    os.lchflags = None\n    os.lchmod = None\n    os.lchown = None\n    os.getcwd = None\n    os.chdir = None\n\n    import shutil\n\n    shutil.rmtree = None\n    shutil.move = None\n    shutil.chown = None\n\n    import subprocess\n\n    subprocess.Popen = None  # type: ignore\n\n    __builtins__[\"help\"] = None\n\n    import sys\n\n    sys.modules[\"ipdb\"] = None\n    sys.modules[\"joblib\"] = None\n    sys.modules[\"resource\"] = None\n    sys.modules[\"psutil\"] = None\n    sys.modules[\"tkinter\"] = None\n'''\n\nTEST_CODE = \"\"\"\nimport json\nimport sys\n\nfrom testing_util import run_test\n\nwith open('test_cases.txt', 'r') as fin:\n    sample = json.load(fin)\nwith open('code.py', 'r') as fin:\n    code = fin.read()\n\nresult, metadata = run_test(sample, code, debug=False, timeout=%(timeout)s)\nif any(x != True for x in result):\n    print(metadata)\n    sys.exit(-1)\nelse:\n    sys.exit(0)\n\"\"\"\n\nLCB_SYSTEM_MESSAGE_GENERIC = \"You are an expert Python programmer. 
You will be given a question (problem specification) and will generate a correct Python program that matches the specification and passes all tests.\"\nLCB_FORMATTING_MESSAGE_WITH_STARTER_CODE = \"You will use the following starter code to write the solution to the problem and enclose your code within delimiters.\"\nLCB_FORMATTING_WITHOUT_STARTER_CODE = \"Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python program runs, it reads the inputs, runs the algorithm and writes output to STDOUT.\"\n\n\ndef fetch_live_code_bench_system_prompt(prompt: str, starter_code: str | None = None) -> str:\n    \"\"\"\n    Build the standard LiveCodeBench system prompt, optionally injecting starter code.\n    \"\"\"\n    prompt = LCB_SYSTEM_MESSAGE_GENERIC + \"\\n\\n\" + prompt\n    if starter_code:\n        prompt += f\"### Format: {LCB_FORMATTING_MESSAGE_WITH_STARTER_CODE}\\n\"\n        prompt += f\"```python\\n{starter_code}\\n```\\n\\n\"\n    else:\n        prompt += f\"### Format: {LCB_FORMATTING_WITHOUT_STARTER_CODE}\\n\"\n        prompt += \"```python\\n# YOUR CODE HERE\\n```\\n\\n\"\n    prompt += \"### Answer: (use the provided format with backticks)\\n\\n\"\n    return prompt\n"
  },
  {
    "path": "tinker_cookbook/recipes/code_rl/sandbox_config/local.yaml",
    "content": "dataset:\n  database:\n    backend:\n      type: none\n    cache:\n      path: memory\n      sources:\n        - type: local\n          path: sandbox/tests/datasets/samples\n  max_runner_concurrency: 32\n  default_dataset_table: code_eval_${dataset_id}\n  registry:\n    - module_path: sandbox.datasets.aider_benchmark\n      class_name: AiderBenchmarkDataset\n      dataset_tables:\n        aider_benchmark_v1: code_eval_aider_benchmark_v1\n    - module_path: sandbox.datasets.autoeval\n      class_name: AutoEvalDataset\n    - module_path: sandbox.datasets.common_oj\n      class_name: CommonOJDataset\n      dataset_tables:\n        code_contests: code_eval_code_contests\n    - module_path: sandbox.datasets.cruxeval\n      class_name: CruxEvalDataset\n      dataset_tables:\n        cruxeval: code_eval_cruxeval\n        cruxeval_x: code_eval_cruxeval_x\n    - module_path: sandbox.datasets.multiple\n      class_name: MultiPLEDataset\n      dataset_tables:\n        multiple_cpp: code_eval_multiple_cpp\n        multiple_ts: code_eval_multiple_ts\n        multiple_sh: code_eval_multiple_sh\n        multiple_cs: code_eval_multiple_cs\n        multiple_go: code_eval_multiple_go\n        multiple_java: code_eval_multiple_java\n        multiple_lua: code_eval_multiple_lua\n        multiple_js: code_eval_multiple_js\n        multiple_php: code_eval_multiple_php\n        multiple_pl: code_eval_multiple_pl\n        multiple_rkt: code_eval_multiple_rkt\n        multiple_r: code_eval_multiple_r\n        multiple_rs: code_eval_multiple_re\n        multiple_scala: code_eval_multiple_scala\n        multiple_swift: code_eval_multiple_swift\n        multiple_rb: code_eval_multiple_rb\n        multiple_d: code_eval_multiple_d\n        multiple_jl: code_eval_multiple_jl\n    - module_path: sandbox.datasets.humaneval\n      class_name: HumanEvalDataset\n      dataset_tables:\n        humaneval_python: code_eval_humaneval_python\n        humaneval_cpp: code_eval_humaneval_cpp\n        humaneval_typescript: code_eval_humaneval_typescript\n        humaneval_bash: code_eval_humaneval_bash\n        humaneval_csharp: code_eval_humaneval_csharp\n        humaneval_go: code_eval_humaneval_go\n        humaneval_java: code_eval_humaneval_java\n        shadow_humaneval_python: code_eval_shadow_humaneval_python\n        bigcodebench: code_eval_bigcodebench\n    - module_path: sandbox.datasets.humanevoeval\n      class_name: EvoEvalDataset\n      dataset_tables:\n        evoeval: code_eval_EvoEval\n    - module_path: sandbox.datasets.live_code_bench\n      class_name: LiveCodeBenchDataset\n      dataset_tables:\n        live_code_bench_v1: code_eval_live_code_bench_v1\n    - module_path: sandbox.datasets.mbpp\n      class_name: MBPPDataset\n      dataset_tables:\n        mbpp: code_eval_mbpp\n    - module_path: sandbox.datasets.mbxp\n      class_name: MBXPDataset\n      dataset_tables:\n        mbxp_v1_en: code_eval_mbxp_v1_en\n        humanevalds_v1_en: code_eval_humanevalds_v1_en\n        oodtest_v1_zh: code_eval_oodtest_v1_zh\n        humanevalds_v2_en: code_eval_humanevalds_v2_en\n        mbxp_v2_en: code_eval_mbxp_v2_en\n    - module_path: sandbox.datasets.mhpp\n      class_name: MHPPDataset\n      dataset_tables:\n        mhpp: code_eval_mhpp\n    - module_path: sandbox.datasets.minif2f\n      class_name: MiniF2FLean4Dataset\n      dataset_tables:\n        minif2f_lean4_test: code_eval_minif2f_lean4_test\n        minif2f_lean4_valid: code_eval_minif2f_lean4_valid\n    - module_path: 
sandbox.datasets.natural_code_bench\n      class_name: NaturalCodeBenchDataset\n      dataset_tables:\n        ncb_python_zh: code_eval_ncb_python_zh\n        ncb_python_en: code_eval_ncb_python_en\n        ncb_java_zh: code_eval_ncb_java_zh\n        ncb_java_en: code_eval_ncb_java_en\n    - module_path: sandbox.datasets.palmath\n      class_name: PalMathDataset\n      dataset_tables:\n        palmath: code_eval_palmath\n    - module_path: sandbox.datasets.repobench_c\n      class_name: RepobenchCDataset\n      dataset_tables:\n        repobench_c_python: code_eval_repobench_c_python_sampled\n        repobench_c_java: code_eval_repobench_c_java_sampled\n    - module_path: sandbox.datasets.repobench_p\n      class_name: RepobenchPDataset\n      dataset_tables:\n        repobench_p_python: code_eval_repobench_p_python_sampled\n        repobench_p_java: code_eval_repobench_p_java_sampled\n    - module_path: sandbox.datasets.verilog\n      class_name: VerilogDataset\n      dataset_tables:\n        verilogeval_human: code_eval_verilogeval_human\n        verilogeval_machine: code_eval_verilogeval_machine\nsandbox:\n  isolation: none\n  cleanup_process: false\n  restore_bash: false\n  max_concurrency: 34\ncommon:\n  logging_color: true\n"
  },
  {
    "path": "tinker_cookbook/recipes/code_rl/train.py",
    "content": "import asyncio\nimport logging\nfrom datetime import datetime\n\nimport chz\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.recipes.code_rl.code_env import DeepcoderDatasetBuilder\nfrom tinker_cookbook.rl.rollout_strategy import RetryOnFailure\nfrom tinker_cookbook.rl.train import AsyncConfig, Config, main\nfrom tinker_cookbook.sandbox import SandboxBackend\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Command-line configuration for DeepCoder RL training.\"\"\"\n\n    # Model configuration\n    model_name: str = \"meta-llama/Llama-3.1-8B-Instruct\"\n    lora_rank: int = 32\n    renderer_name: str | None = None\n    load_checkpoint_path: str | None = None\n\n    # Data / environment configuration\n    seed: int = 0\n\n    # Training hyperparameters\n    group_size: int = 4\n    groups_per_batch: int = 100\n    learning_rate: float = 1e-5\n    max_tokens: int = 5\n    kl_penalty_coef: float = 0.0\n    num_substeps: int = 1\n\n    # Logging / eval / checkpoints\n    log_dir: str | None = None\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    compute_post_kl: bool = False\n    eval_every: int = 20\n    save_every: int = 20\n\n    # Service configuration\n    base_url: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    # Async rollout configuration\n    max_steps_off_policy: int | None = None\n\n    # Code execution sandbox configuration\n    sandbox_backend: SandboxBackend = SandboxBackend.SANDBOXFUSION\n\n    max_steps: int | None = None\n\n    # Maximum number of times to retry a failed trajectory rollout (container crash,\n    # sandbox flake, etc.). None (default) = crash on any error. 
0+ = retry budget.\n    rollout_max_retries: int | None = None\n\n\nasync def cli_main(cli_config: CLIConfig) -> None:\n    renderer_name = await checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n\n    model_tag = cli_config.model_name.replace(\"/\", \"-\")\n    run_name = (\n        f\"deepcoder-{model_tag}-{cli_config.lora_rank}rank-\"\n        f\"{cli_config.learning_rate}lr-{cli_config.group_size}group-\"\n        f\"{cli_config.groups_per_batch}batch-seed{cli_config.seed}-\"\n        f\"{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n    )\n\n    # Set log path\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/code_rl/{run_name}\"\n\n    wandb_name = cli_config.wandb_name or run_name\n\n    dataset_builder = DeepcoderDatasetBuilder(\n        batch_size=cli_config.groups_per_batch,\n        model_name_for_tokenizer=cli_config.model_name,\n        renderer_name=renderer_name,\n        group_size=cli_config.group_size,\n        seed=cli_config.seed,\n        sandbox_backend=cli_config.sandbox_backend,\n    )\n\n    config = Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_builder=dataset_builder,\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        log_path=log_path,\n        base_url=cli_config.base_url,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        compute_post_kl=cli_config.compute_post_kl,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        num_substeps=cli_config.num_substeps,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        async_config=AsyncConfig(\n            max_steps_off_policy=cli_config.max_steps_off_policy,\n            groups_per_batch=cli_config.groups_per_batch,\n        )\n        if cli_config.max_steps_off_policy is not None\n        else None,\n        max_steps=cli_config.max_steps,\n        rollout_error_tolerance=RetryOnFailure(max_retries=cli_config.rollout_max_retries)\n        if cli_config.rollout_max_retries is not None\n        else False,\n    )\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    await main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/distillation/README.md",
    "content": "# Distillation\n\nDistillation refers to a class of methods where a teacher model is supervising the training of a student model, which can often be more efficient than training the student model in isolation. We provide off-policy and on-policy distillation recipes on top of the [OpenThoughts3](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M), [DeepMath](https://huggingface.co/datasets/zwhe99/DeepMath-103K), and [Tulu3](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture)* datasets.\n\nSpecifically, we provide the scripts needed to reproduce our experiments from the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation) blog post, which can be run with LoRA using Tinker.\n\n\\* For our post, we regenerated the assistant turns using [Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).\n\n## Distillation for reasoning\n\nOur results can be reproduced by running:\n1. Supervised fine-tuning on [OpenThoughts3](https://huggingface.co/datasets/open-thoughts/OpenThoughts3-1.2M)\n2. On-policy distillation on [DeepMath](https://huggingface.co/datasets/zwhe99/DeepMath-103K)\n\n### Supervised fine-tuning\n\nWe observe an AIME'24 score of ~55% using a rank-128 LoRA after 3000 steps. We use a learning rate of 1e-3 with LoRA and 1e-4 with full fine-tuning.\n\n```bash\npython -m tinker_cookbook.recipes.distillation.off_policy_reasoning \\\n    model_name=Qwen/Qwen3-8B-Base \\\n    learning_rate=1e-3 \\\n    batch_size=128 \\\n    lora_rank=128 \\\n    wandb_project=cookbook_distillation\n```\n\n### On-policy distillation\n\nWe observe an AIME'24 score of ~65% using a rank-128 LoRA after 100 steps. For on-policy distillation experiments, we use a learning rate of 1e-4 with LoRA and 5e-5 with full fine-tuning.\n\n```bash\npython -m tinker_cookbook.recipes.distillation.on_policy_distillation \\\n    model_name=Qwen/Qwen3-8B-Base \\\n    load_checkpoint_path=tinker://4a1939e6-04be-5a77-9e4e-910ccff9f27e:train:0/weights/final \\\n    dataset=deepmath \\\n    learning_rate=1e-4 \\\n    groups_per_batch=512 \\\n    lora_rank=128 \\\n    wandb_project=cookbook_distillation\n```\n\nThis script can also be used to replicate the experiments in our Discussion section, after you have run RL to obtain an appropriate checkpoint for the teacher model.\n\n### Checkpoints\n\nThe results of running the above scripts with various LoRA ranks can be found here:\n\n| Stage | Rank 8 | Rank 32 | Rank 128 |\n|-------|--------|---------|----------|\n| Supervised fine-tuning (SFT) | `tinker://c15f09f1-f225-5f98-bab1-ec8dfac5da2a:train:0/weights/final` | `tinker://b9190d16-c849-51d5-a690-1b5de146a284:train:0/weights/final` | `tinker://4a1939e6-04be-5a77-9e4e-910ccff9f27e:train:0/weights/final` |\n| On-policy distillation | `tinker://4a97bc02-f4d0-5ecd-888a-3e8cc5b0f7f6:train:0/sampler_weights/000080` | `tinker://bfffa2b2-a78f-59be-a2ef-cc9188bfce7e:train:0/sampler_weights/000080` | `tinker://1dd8de47-be86-54d3-9355-ebf80827be26:train:0/sampler_weights/000080` |\n\nSee the on-policy distillation launch command above for an example of how to load the checkpoint path.\n\n## Distillation for personalization\n\nIn this section, we ran:\n1. Supervised fine-tuning on internal documents + resampled Tulu3 data\n2. 
On-policy distillation on [Tulu3](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) prompts\n\n### On-policy distillation\n\nIn our experiment, we saw [IF-eval](https://huggingface.co/datasets/google/IFEval) recover within approximately 100 steps; we expect similar results in other settings. In order to use this script, you will have to provide your own SFT initialization.\n\n```bash\npython -m tinker_cookbook.recipes.distillation.on_policy_distillation \\\n    model_name=Qwen/Qwen3-8B-Base \\\n    dataset=tulu3 \\\n    learning_rate=1e-4 \\\n    groups_per_batch=64 \\\n    lora_rank=128 \\\n    wandb_project=cookbook_distillation\n```\n\n## Distillation for multi-turn tool use\n\nThe recipes above are single-turn — the student generates one response per prompt and receives a KL signal against the teacher. This works for reasoning and chat, but tool-use tasks require multi-turn interaction: the agent calls tools, observes results, and iterates. This recipe extends on-policy distillation to multi-turn tool-use episodes using Harbor sandbox environments.\n\n### Architecture\n\nMulti-turn distillation reuses three layers of infrastructure:\n\n1. **`tool_use` library** (`tinker_cookbook/tool_use/`) — Generic agent-tool interaction. `@tool` decorator defines tools, `build_agent_tool_env()` creates token-level RL environments from tools + renderer + reward function. `AgentToolMessageEnv` manages the message-level episode loop (append assistant message → execute tool calls → check termination).\n\n2. **`harbor_rl` recipe** (`tinker_cookbook/recipes/harbor_rl/`) — Applies `tool_use` to Harbor sandbox tasks. `HarborBashTool` wraps a sandbox as a `@tool`-decorated bash command. `HarborEnvGroupBuilder` creates sandboxed environments with task-specific grading via `HarborReward`.\n\n3. **Multi-turn distillation** (`tinker_cookbook/recipes/distillation/harbor_multiturn.py`) — `HarborDistillationDatasetBuilder` subclasses `HarborDatasetBuilder`, passing `reward_fn=zero_reward` (always returns 0.0) to override the default `HarborReward`. The only training signal is KL divergence against the teacher.\n\nEnvironment-provided tokens (system prompt, user message, tool responses, assistant headers) are masked out during training — only the student's generated tokens contribute to the loss.\n\n### Data setup\n\nThis recipe uses Harbor sandbox tasks. 
To get started:\n\n- **Download:** `uvx harbor datasets download terminal-bench@2.0` (lands in `~/.cache/harbor/tasks/`)\n- **Load:** `load_terminal_bench_tasks()` from `tinker_cookbook.recipes.harbor_rl.launch_terminal_bench`\n- **Custom tasks:** any `HarborTask` with `environment/Dockerfile` and `tests/test.sh`\n\nSee `tinker_cookbook/recipes/harbor_rl/README.md` for full details on the HarborTask format and sandbox protocol.\n\n### On-policy distillation (Harbor)\n\n```bash\npython -m tinker_cookbook.recipes.distillation.on_policy_distillation_harbor_multi_turn \\\n    model_name=moonshotai/Kimi-K2-Thinking \\\n    teacher_model=moonshotai/Kimi-K2-Thinking \\\n    max_turns=10 \\\n    group_size=4 \\\n    groups_per_batch=8 \\\n    learning_rate=1e-4 \\\n    lora_rank=8 \\\n    max_tokens=2048 \\\n    max_trajectory_tokens=24576 \\\n    temperature=1.0 \\\n    kl_penalty_coef=1.0 \\\n    sandbox_timeout=600 \\\n    command_timeout=120 \\\n    save_every=5 \\\n    eval_every=5 \\\n    wandb_name=cookbook-multiturn-onpodi\n```\n\n## Additional details\n\n### Reward calculation\n\nIn on-policy distillation, we use an `Environment` that has no rewards (neither correctness nor format). The only supervision comes from minimizing the KL against a teacher model. You can optionally increase `kl_discount_factor` to optimize discounted future KL, but we generally do not observe this to improve performance.\n\n### Distillation with multiple teachers\n\nFor every dataset, we can define a teacher model and batch size (`groups_per_batch`) to use:\n\n```python\n{\n    \"dataset_builder\": RLDatasetBuilder,\n    \"teacher_model\": {\n        \"base_model\": str,  # e.g. \"Qwen/Qwen3-32B\"\n        \"load_checkpoint_path\": str | None  # e.g. \"tinker://<unique_id>/sampler_weights/final\"\n    },\n    \"groups_per_batch\": int\n}\n```\n\nThe trainer will then sample from each configuration, and concatenate all the individual dataset batches to form the batch for training. This can be used to run multi-teacher distillation, although we do not showcase this in our blog post.\n\n```bash\npython -m tinker_cookbook.recipes.distillation.on_policy_multi_teacher \\\n    learning_rate=1e-4 \\\n    deepmath_groups_per_batch=256 \\\n    tulu3_groups_per_batch=256 \\\n    lora_rank=128 \\\n    wandb_project=cookbook_distillation\n```\n"
  },
  {
    "path": "tinker_cookbook/recipes/distillation/harbor_multiturn.py",
    "content": "\"\"\"Harbor environment for multi-turn on-policy distillation.\n\nProvides a DatasetBuilder that creates harbor sandbox environments with zero\nreward. The only training signal comes from KL divergence against a teacher\nmodel (computed in the training loop).\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\n\nimport chz\n\nfrom tinker_cookbook.recipes.harbor_rl.harbor_env import HarborDataset, HarborDatasetBuilder\nfrom tinker_cookbook.renderers.base import Message\nfrom tinker_cookbook.rl.types import RLDataset\n\nlogger = logging.getLogger(__name__)\n\n\nasync def zero_reward(history: list[Message]) -> tuple[float, dict[str, float]]:\n    \"\"\"Reward function that always returns zero. KL penalty is the only signal.\"\"\"\n    return 0.0, {}\n\n\n@chz.chz\nclass HarborDistillationDatasetBuilder(HarborDatasetBuilder):\n    \"\"\"Build a distillation dataset over Harbor tasks (zero reward, KL only).\"\"\"\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:\n        train_dataset = HarborDataset(\n            env_group_builders=self._make_env_group_builders(self.group_size),\n            batch_size=self.batch_size,\n        )\n        return train_dataset, None\n"
  },
  {
    "path": "tinker_cookbook/recipes/distillation/harbor_multiturn_test.py",
    "content": "\"\"\"Tests for harbor_multiturn zero-reward distillation.\"\"\"\n\nimport asyncio\nfrom unittest.mock import MagicMock\n\nimport pytest\n\nfrom tinker_cookbook.renderers.base import Message\n\n\nclass TestZeroReward:\n    def test_zero_reward_returns_zero(self):\n        mod = pytest.importorskip(\n            \"tinker_cookbook.recipes.distillation.harbor_multiturn\",\n            reason=\"requires modal\",\n        )\n        result = asyncio.run(mod.zero_reward([Message(role=\"user\", content=\"test\")]))\n        assert result == (0.0, {})\n\n    def test_zero_reward_ignores_history_content(self):\n        mod = pytest.importorskip(\n            \"tinker_cookbook.recipes.distillation.harbor_multiturn\",\n            reason=\"requires modal\",\n        )\n        for history in [[], [Message(role=\"user\", content=\"x\")] * 50]:\n            result = asyncio.run(mod.zero_reward(history))\n            assert result == (0.0, {})\n\n    def test_env_group_builder_compute_group_rewards_returns_zeros(self):\n        mod = pytest.importorskip(\n            \"tinker_cookbook.recipes.distillation.harbor_multiturn\",\n            reason=\"requires modal\",\n        )\n        harbor_env = pytest.importorskip(\n            \"tinker_cookbook.recipes.harbor_rl.harbor_env\",\n            reason=\"requires modal\",\n        )\n        builder = harbor_env.HarborEnvGroupBuilder(\n            task=MagicMock(),\n            model_name=\"test\",\n            renderer_name=\"test\",\n            max_turns=5,\n            group_size=2,\n            reward_fn=mod.zero_reward,\n        )\n        builder._sandboxes = []\n        trajectories = [MagicMock(), MagicMock()]\n        result = asyncio.run(builder.compute_group_rewards(trajectories, []))\n        assert result == [(0.0, {}), (0.0, {})]\n"
  },
  {
    "path": "tinker_cookbook/recipes/distillation/off_policy_reasoning.py",
    "content": "\"\"\"\nSupervised fine-tuning for reasoning tasks using OpenThoughts3.\n\nThis script implements standard supervised learning on the OpenThoughts3 dataset,\nwhich contains reasoning traces with chain-of-thought style responses.\n\nExample usage:\n    python -m tinker_cookbook.recipes.distillation.off_policy_reasoning \\\n        model_name=Qwen/Qwen3-8B-Base \\\n        learning_rate=1e-4 \\\n        batch_size=128 \\\n        lora_rank=128 \\\n        wandb_project=cookbook_distillation\n\"\"\"\n\nimport asyncio\nimport logging\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import cast\n\nimport chz\nimport datasets\nimport tinker\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.renderers import Message, TrainOnWhat\nfrom tinker_cookbook.supervised import train\nfrom tinker_cookbook.supervised.data import (\n    StreamingSupervisedDatasetFromHFDataset,\n    conversation_to_datum,\n)\nfrom tinker_cookbook.supervised.types import (\n    ChatDatasetBuilder,\n    ChatDatasetBuilderCommonConfig,\n    SupervisedDataset,\n)\nfrom tinker_cookbook.utils.lr_scheduling import LRSchedule\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass OpenThoughts3Builder(ChatDatasetBuilder):\n    \"\"\"Builder for OpenThoughts3 dataset with streaming support.\"\"\"\n\n    buffer_size: int = 128 * 3000  # Buffer for shuffle\n    max_prompts: int = 128 * 3000  # Maximum number of prompts to train on\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        # Load streaming dataset\n        ds = datasets.load_dataset(\n            \"open-thoughts/OpenThoughts3-1.2M\", split=\"train\", streaming=True\n        )\n        ds = cast(datasets.IterableDataset, ds)\n\n        # Use train_on_what from common_config if provided, otherwise default to ALL_ASSISTANT_MESSAGES\n        train_on_what = (\n            TrainOnWhat(self.common_config.train_on_what)\n            if self.common_config.train_on_what\n            else TrainOnWhat.ALL_ASSISTANT_MESSAGES\n        )\n\n        def map_fn(row: dict) -> tinker.Datum:\n            # Convert OpenThoughts3 format (from/value) to standard format (role/content)\n            conversations = row.get(\"conversations\", [])\n            messages: list[Message] = [\n                {\n                    \"role\": \"user\" if msg[\"from\"] == \"human\" else \"assistant\",\n                    \"content\": msg[\"value\"],\n                }\n                for msg in conversations\n            ]\n            return conversation_to_datum(\n                messages, self.renderer, self.common_config.max_length, train_on_what\n            )\n\n        train_dataset = StreamingSupervisedDatasetFromHFDataset(\n            hf_dataset=ds,\n            batch_size=self.common_config.batch_size,\n            length=self.max_prompts,\n            map_fn=map_fn,\n            buffer_size=self.buffer_size,\n        )\n\n        # No test dataset for OpenThoughts3\n        return train_dataset, None\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Command-line configuration for SFT on OpenThoughts3.\"\"\"\n\n    # Model configuration\n    model_name: str = \"Qwen/Qwen3-8B-Base\"\n    lora_rank: int = 128\n    renderer_name: str | None = \"qwen3\"\n    load_checkpoint_path: str | None = None\n\n    # Training hyperparameters\n    batch_size: int = 128\n    learning_rate: float = 1e-3\n    lr_schedule: LRSchedule = \"linear\"\n    num_epochs: int = 1\n    max_length: int = 16384\n\n    # Dataset 
configuration\n    buffer_size: int = 128 * 3000  # Buffer for randomized shuffle\n    max_prompts: int = 128 * 3000  # Maximum number of prompts to train on\n\n    # Logging configuration\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    # Evaluation and checkpointing\n    eval_every: int = 50\n    save_every: int = 50\n\n    # Service configuration\n    base_url: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\ndef cli_main(cli_config: CLIConfig):\n    \"\"\"Convert CLI config to full config and run training.\"\"\"\n\n    # Get renderer name\n    renderer_name = checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n\n    # Create log path if not specified\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n        run_name = Path(log_path).name\n    else:\n        model_name = cli_config.model_name.replace(\"/\", \"-\")\n        run_name = (\n            f\"sft-openthoughts3-{model_name}-\"\n            f\"{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-\"\n            f\"{cli_config.batch_size}batch-{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n        )\n        log_path = f\"/tmp/tinker-examples/distillation/{run_name}\"\n\n    # Create wandb name if not specified\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    # Create dataset builder\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=cli_config.model_name,\n        renderer_name=renderer_name,\n        max_length=cli_config.max_length,\n        batch_size=cli_config.batch_size,\n        train_on_what=None,  # Use default in OpenThoughts3Builder\n    )\n\n    dataset_builder = OpenThoughts3Builder(\n        common_config=common_config,\n        buffer_size=cli_config.buffer_size,\n        max_prompts=cli_config.max_prompts,\n    )\n\n    # Create full config\n    config = train.Config(\n        log_path=log_path,\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        dataset_builder=dataset_builder,\n        evaluator_builders=[],\n        infrequent_evaluator_builders=[],\n        learning_rate=cli_config.learning_rate,\n        lr_schedule=cli_config.lr_schedule,\n        num_epochs=cli_config.num_epochs,\n        base_url=cli_config.base_url,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        lora_rank=cli_config.lora_rank,\n        save_every=cli_config.save_every,\n        eval_every=cli_config.eval_every,\n        max_steps=cli_config.max_steps,\n    )\n\n    # Run training\n    asyncio.run(train.main(config))\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    cli_main(cli_config)\n"
  },
  {
    "path": "tinker_cookbook/recipes/distillation/on_policy_distillation.py",
    "content": "\"\"\"\nOn-policy distillation for reasoning and chat tasks.\n\nThis script implements on-policy distillation where a student model learns from\na teacher model by minimizing KL divergence. No correctness or format rewards\nare used - only KL penalty provides supervision.\n\nExample usage:\n    # For reasoning tasks (DeepMath)\n    python -m tinker_cookbook.recipes.distillation.on_policy_distillation \\\n        model_name=Qwen/Qwen3-8B-Base \\\n        dataset=deepmath \\\n        learning_rate=1e-4 \\\n        groups_per_batch=1024 \\\n        lora_rank=128 \\\n        wandb_project=cookbook_distillation\n\n    # For chat tasks (Tulu3)\n    python -m tinker_cookbook.recipes.distillation.on_policy_distillation \\\n        model_name=Qwen/Qwen3-8B-Base \\\n        dataset=tulu3 \\\n        learning_rate=1e-4 \\\n        groups_per_batch=1024 \\\n        lora_rank=128 \\\n        wandb_project=cookbook_distillation\n\"\"\"\n\nimport asyncio\nimport logging\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any\n\nimport chz\nfrom tinker.types import LossFnType\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.distillation import train_on_policy\nfrom tinker_cookbook.distillation.datasets import (\n    DistillationDatasetConfig,\n    PromptOnlyDatasetBuilder,\n    TeacherConfig,\n)\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Command-line configuration for on-policy distillation.\"\"\"\n\n    # Model configuration\n    model_name: str = \"Qwen/Qwen3-8B-Base\"  # Student model\n    lora_rank: int = 128\n    renderer_name: str | None = None\n    load_checkpoint_path: str | None = None  # Student checkpoint\n\n    # Teacher configuration\n    teacher_model: str = \"Qwen/Qwen3-8B\"\n    teacher_checkpoint: str | None = None\n\n    # Dataset configuration\n    dataset: str = \"deepmath\"  # Options: deepmath, tulu3\n\n    # Training hyperparameters\n    group_size: int = 4  # Number of rollouts per prompt\n    groups_per_batch: int = 1024\n    learning_rate: float = 1e-4\n    max_tokens: int = 4096\n    temperature: float = 1.0\n    kl_penalty_coef: float = 1.0\n    kl_discount_factor: float = 0.0\n\n    # Optimizer configuration\n    num_substeps: int = 1\n\n    # Loss function and configuration.\n    # See https://tinker-docs.thinkingmachines.ai/losses\n    loss_fn: LossFnType = \"importance_sampling\"\n    loss_fn_config: dict[str, Any] | None = None\n\n    # Logging configuration\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    compute_post_kl: bool = False\n\n    # Evaluation and checkpointing\n    eval_every: int = 20\n    save_every: int = 20\n\n    # Service configuration\n    base_url: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\nasync def cli_main(cli_config: CLIConfig):\n    \"\"\"Convert CLI config to full config and run training.\"\"\"\n\n    # Get renderer name\n    renderer_name = await checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n\n    # Create log path if not specified\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        model_name = 
cli_config.model_name.replace(\"/\", \"-\")\n        run_name = (\n            f\"distill-{cli_config.dataset}-{model_name}-\"\n            f\"{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-\"\n            f\"{cli_config.groups_per_batch}batch-{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n        )\n        log_path = f\"/tmp/tinker-examples/distillation/{run_name}\"\n\n    # Create wandb name if not specified\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = Path(log_path).name\n\n    # Create dataset builder\n    dataset_builder = PromptOnlyDatasetBuilder(\n        dataset_name=cli_config.dataset,\n        groups_per_batch=cli_config.groups_per_batch,\n        group_size=cli_config.group_size,\n        model_name_for_tokenizer=cli_config.model_name,\n        renderer_name=renderer_name,\n    )\n\n    # Create teacher config\n    teacher_config = TeacherConfig(\n        base_model=cli_config.teacher_model,\n        load_checkpoint_path=cli_config.teacher_checkpoint,\n    )\n\n    # Create distillation dataset config\n    dataset_config = DistillationDatasetConfig(\n        dataset_builder=dataset_builder,\n        teacher_config=teacher_config,\n        groups_per_batch=cli_config.groups_per_batch,\n    )\n\n    # Create full config\n    config = train_on_policy.Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_configs=[dataset_config],\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        kl_discount_factor=cli_config.kl_discount_factor,\n        num_substeps=cli_config.num_substeps,\n        loss_fn=cli_config.loss_fn,\n        loss_fn_config=cli_config.loss_fn_config,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        log_path=log_path,\n        base_url=cli_config.base_url,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        compute_post_kl=cli_config.compute_post_kl,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        max_steps=cli_config.max_steps,\n    )\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    # Run training\n    await train_on_policy.main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/distillation/on_policy_distillation_harbor_multi_turn.py",
    "content": "\"\"\"\nMulti-turn on-policy distillation with Harbor sandbox environments.\nYou need to download the tasks from the harbor cache first.\n  uvx harbor datasets download terminal-bench@2.0\n\nThe student interacts with a harbor sandbox over multiple turns (tool calls),\nand the training signal comes from KL divergence against a teacher model.\nEnvironment responses are masked out; only student-generated tokens contribute.\n\nExample usage:\n    python -m tinker_cookbook.recipes.distillation.on_policy_distillation_harbor_multi_turn \\\n        model_name=moonshotai/Kimi-K2-Thinking \\\n        teacher_model=moonshotai/Kimi-K2-Thinking \\\n        max_turns=10 \\\n        group_size=4 \\\n        groups_per_batch=8 \\\n        learning_rate=1e-4 \\\n        lora_rank=8 \\\n        max_tokens=2048 \\\n        max_trajectory_tokens=24576 \\\n        temperature=1.0 \\\n        kl_penalty_coef=1.0 \\\n        sandbox_timeout=600 \\\n        command_timeout=120 \\\n        save_every=5 \\\n        eval_every=5 \\\n        wandb_name=cookbook-multiturn-onpodi\n\"\"\"\n\nimport asyncio\nimport logging\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any\n\nimport chz\nfrom tinker.types import LossFnType\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.distillation import train_on_policy\nfrom tinker_cookbook.distillation.datasets import DistillationDatasetConfig, TeacherConfig\nfrom tinker_cookbook.recipes.distillation.harbor_multiturn import (\n    HarborDistillationDatasetBuilder,\n    zero_reward,\n)\nfrom tinker_cookbook.recipes.harbor_rl.harbor_env import (\n    HarborTask,\n    default_sandbox_factory,\n    load_harbor_tasks,\n)\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Command-line configuration for multi-turn harbor distillation.\"\"\"\n\n    # Student model\n    model_name: str = \"moonshotai/Kimi-K2-Thinking\"\n    lora_rank: int = 32\n    renderer_name: str | None = None\n    load_checkpoint_path: str | None = None\n\n    # Teacher model\n    teacher_model: str = \"moonshotai/Kimi-K2-Thinking\"\n    teacher_checkpoint: str | None = None\n\n    # Harbor environment\n    task_name: str = \"terminal-bench-2.0\"\n    max_turns: int = 10\n    sandbox_timeout: int = 600\n    command_timeout: int = 120\n    max_trajectory_tokens: int = 32 * 1024\n\n    # Training hyperparameters\n    group_size: int = 4\n    groups_per_batch: int = 8\n    learning_rate: float = 1e-4\n    max_tokens: int = 8192\n    temperature: float = 1.0\n    kl_penalty_coef: float = 1.0\n    kl_discount_factor: float = 0.0\n\n    # Optimizer configuration\n    num_substeps: int = 1\n\n    # Loss function\n    loss_fn: LossFnType = \"importance_sampling\"\n    loss_fn_config: dict[str, Any] | None = None\n\n    # Logging\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    # Evaluation and checkpointing\n    eval_every: int = 20\n    save_every: int = 20\n\n    # Service configuration\n    base_url: str | None = None\n\n    max_steps: int | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n\nasync def cli_main(cli_config: CLIConfig, tasks: list[HarborTask]):\n    \"\"\"Load harbor tasks, build distillation config, and run training.\"\"\"\n\n    renderer_name = await checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async(\n        model_name=cli_config.model_name,\n        
explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n\n    # Build log path\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        model_tag = cli_config.model_name.replace(\"/\", \"-\")\n        run_name = (\n            f\"distill-harbor-{model_tag}-{cli_config.lora_rank}rank-\"\n            f\"{cli_config.learning_rate}lr-{cli_config.groups_per_batch}batch-\"\n            f\"{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n        )\n        log_path = str(Path(f\"~/tinker-examples/distillation/{run_name}\").expanduser())\n\n    wandb_name = cli_config.wandb_name or Path(log_path).name\n    logger.info(\"Loaded %d harbor tasks\", len(tasks))\n\n    # Build dataset\n    dataset_builder = HarborDistillationDatasetBuilder(\n        tasks=tasks,\n        batch_size=cli_config.groups_per_batch,\n        group_size=cli_config.group_size,\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        max_turns=cli_config.max_turns,\n        sandbox_timeout=cli_config.sandbox_timeout,\n        command_timeout=cli_config.command_timeout,\n        max_trajectory_tokens=cli_config.max_trajectory_tokens,\n        sandbox_factory=default_sandbox_factory,\n        reward_fn=zero_reward,\n    )\n\n    teacher_config = TeacherConfig(\n        base_model=cli_config.teacher_model,\n        load_checkpoint_path=cli_config.teacher_checkpoint,\n    )\n\n    dataset_config = DistillationDatasetConfig(\n        dataset_builder=dataset_builder,\n        teacher_config=teacher_config,\n        groups_per_batch=cli_config.groups_per_batch,\n    )\n\n    config = train_on_policy.Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_configs=[dataset_config],\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        temperature=cli_config.temperature,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        kl_discount_factor=cli_config.kl_discount_factor,\n        num_substeps=cli_config.num_substeps,\n        loss_fn=cli_config.loss_fn,\n        loss_fn_config=cli_config.loss_fn_config,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        log_path=log_path,\n        base_url=cli_config.base_url,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        max_steps=cli_config.max_steps,\n    )\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    await train_on_policy.main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    tasks = load_harbor_tasks(cli_config.task_name)\n    asyncio.run(cli_main(cli_config, tasks))\n"
  },
  {
    "path": "tinker_cookbook/recipes/distillation/on_policy_multi_teacher.py",
    "content": "\"\"\"\nMulti-teacher on-policy distillation example.\n\nThis script demonstrates on-policy distillation with multiple datasets and\ndifferent teacher models for each dataset. It uses:\n- DeepMath dataset with Qwen3-32B as teacher\n- Tulu3 dataset with Qwen3-235B-A22B-Instruct-2507 as teacher\n- Qwen3-8B as student model\n- qwen3_instruct renderer\n\nExample usage:\n    python -m tinker_cookbook.recipes.distillation.on_policy_multi_teacher \\\n        learning_rate=1e-4 \\\n        deepmath_groups_per_batch=256 \\\n        tulu3_groups_per_batch=256 \\\n        lora_rank=128 \\\n        wandb_project=cookbook_distillation\n\"\"\"\n\nimport asyncio\nimport logging\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any\n\nimport chz\nfrom tinker.types import LossFnType\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.distillation import train_on_policy\nfrom tinker_cookbook.distillation.datasets import (\n    DistillationDatasetConfig,\n    PromptOnlyDatasetBuilder,\n    TeacherConfig,\n)\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Command-line configuration for multi-teacher on-policy distillation.\"\"\"\n\n    # Model configuration\n    model_name: str = \"Qwen/Qwen3-8B\"  # Student model\n    lora_rank: int = 128\n    renderer_name: str | None = None\n    load_checkpoint_path: str | None = None  # Student checkpoint\n\n    # Teacher configurations\n    deepmath_teacher_model: str = \"Qwen/Qwen3-32B\"\n    deepmath_teacher_checkpoint: str | None = None\n    tulu3_teacher_model: str = \"Qwen/Qwen3-235B-A22B-Instruct-2507\"\n    tulu3_teacher_checkpoint: str | None = None\n\n    # Dataset configuration\n    deepmath_groups_per_batch: int = 512\n    tulu3_groups_per_batch: int = 512\n\n    # Training hyperparameters\n    group_size: int = 4  # Number of rollouts per prompt\n    learning_rate: float = 1e-4\n    max_tokens: int = 4096\n    temperature: float = 1.0\n    kl_penalty_coef: float = 1.0\n    kl_discount_factor: float = 0.0\n\n    # Optimizer configuration\n    num_substeps: int = 1\n\n    # Loss function and configuration.\n    # See https://tinker-docs.thinkingmachines.ai/losses\n    loss_fn: LossFnType = \"importance_sampling\"\n    loss_fn_config: dict[str, Any] | None = None\n\n    # Logging configuration\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    compute_post_kl: bool = False\n\n    # Evaluation and checkpointing\n    eval_every: int = 20\n    save_every: int = 20\n\n    # Service configuration\n    base_url: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\nasync def cli_main(cli_config: CLIConfig):\n    \"\"\"Convert CLI config to full config and run training.\"\"\"\n    renderer_name = await checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n\n    # Create log path if not specified\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        model_name = cli_config.model_name.replace(\"/\", \"-\")\n        run_name = (\n            f\"distill-multi-{model_name}-\"\n            f\"{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-\"\n            
f\"dm{cli_config.deepmath_groups_per_batch}-t3{cli_config.tulu3_groups_per_batch}-\"\n            f\"{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n        )\n        log_path = f\"/tmp/tinker-examples/distillation/{run_name}\"\n\n    # Create wandb name if not specified\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = Path(log_path).name\n\n    # Create DeepMath dataset builder\n    deepmath_builder = PromptOnlyDatasetBuilder(\n        dataset_name=\"deepmath\",\n        groups_per_batch=cli_config.deepmath_groups_per_batch,\n        group_size=cli_config.group_size,\n        model_name_for_tokenizer=cli_config.model_name,\n        renderer_name=renderer_name,\n    )\n\n    # Create Tulu3 dataset builder\n    tulu3_builder = PromptOnlyDatasetBuilder(\n        dataset_name=\"tulu3\",\n        groups_per_batch=cli_config.tulu3_groups_per_batch,\n        group_size=cli_config.group_size,\n        model_name_for_tokenizer=cli_config.model_name,\n        renderer_name=renderer_name,\n    )\n\n    # Create teacher configs\n    deepmath_teacher_config = TeacherConfig(\n        base_model=cli_config.deepmath_teacher_model,\n        load_checkpoint_path=cli_config.deepmath_teacher_checkpoint,\n    )\n\n    tulu3_teacher_config = TeacherConfig(\n        base_model=cli_config.tulu3_teacher_model,\n        load_checkpoint_path=cli_config.tulu3_teacher_checkpoint,\n    )\n\n    # Create distillation dataset configs\n    deepmath_dataset_config = DistillationDatasetConfig(\n        dataset_builder=deepmath_builder,\n        teacher_config=deepmath_teacher_config,\n        groups_per_batch=cli_config.deepmath_groups_per_batch,\n    )\n\n    tulu3_dataset_config = DistillationDatasetConfig(\n        dataset_builder=tulu3_builder,\n        teacher_config=tulu3_teacher_config,\n        groups_per_batch=cli_config.tulu3_groups_per_batch,\n    )\n\n    # Create full config with both datasets\n    config = train_on_policy.Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_configs=[deepmath_dataset_config, tulu3_dataset_config],\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        kl_discount_factor=cli_config.kl_discount_factor,\n        num_substeps=cli_config.num_substeps,\n        loss_fn=cli_config.loss_fn,\n        loss_fn_config=cli_config.loss_fn_config,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        log_path=log_path,\n        base_url=cli_config.base_url,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        compute_post_kl=cli_config.compute_post_kl,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        max_steps=cli_config.max_steps,\n    )\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    # Run training\n    await train_on_policy.main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/README.md",
    "content": "# Harbor RL\n\n## Installation\n\n```bash\nuv pip install 'tinker-cookbook[modal] @ git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly'\n```\n\nRL training on Harbor formatted tasks (e.g., Terminal Bench 2.0) with sandboxed code execution. An agent gets a bash tool inside a sandboxed container, attempts a task, and receives reward based on test results.\n\n## HarborTask\nHarbor offers a standardized format for SWE/Terminal-Bench style task.\nAdhering to this allows seperation between task creation layer and evaluation/training harness layer.\nWe can download the harbor datasets through `uvx harbor datasets download terminal-bench@2.0`.\nBy default, the task will land in `~/.cache/harbor/tasks/` with the structure\n```\n~/.cache/harbor/tasks/\n  └── <shortuuid(task_id)>/       # deterministic hash for deduplication\n      └── <task_name>/            # human-readable task directory\n          ├── environment/\n          │   └── Dockerfile\n          ├── tests/\n          │   └── test.sh\n          ├── instruction.md\n          ├── task.toml\n          └── solution/\n```\nTo use harbor tasks for training or evaluation, we designed the following interface\n\n```python\n@dataclass(frozen=True)\nclass HarborTask:\n    task_name: str\n    instruction: str\n    task_dir: Path      # must contain environment/Dockerfile and tests/test.sh\n    config: dict[str, Any] = field(default_factory=dict)\n```\n\nYou can load your downloaded tasks (e.g., 89 Terminal-Bench tasks) via `load_harbor_tasks()` in `launch_terminal_bench.py`:\n\n```python\nfrom tinker_cookbook.recipes.harbor_rl.launch_terminal_bench import load_harbor_tasks\n\ntasks = load_harbor_tasks()  # reads from ~/.cache/harbor/tasks/ by default\nprint(f\"Loaded {len(tasks)} tasks\")\nprint(tasks[0].task_name, tasks[0].task_dir)\n```\nThe training environment is implemented against this interface.\nYou can customize your own task as long as they conforms to the interface above.\n\n## Sandbox Protocol and custom backends\n\n### The Protocol\n\n`tinker_cookbook.sandbox.sandbox_interface` defines `SandboxInterface`:\n\n```python\n@runtime_checkable\nclass SandboxInterface(Protocol):\n    async def run_command(self, command: str, workdir: str | None = None, timeout: int = 60, max_output_bytes: int | None = None) -> SandboxResult: ...\n    async def read_file(self, path: str, max_bytes: int | None = None, timeout: int = 60) -> SandboxResult: ...\n    async def write_file(self, path: str, content: str | bytes, executable: bool = False, timeout: int = 60) -> SandboxResult: ...\n    async def send_heartbeat(self) -> None: ...\n    async def cleanup(self) -> None: ...\n```\n\n`ModalSandbox` implements this interface.\n\n### SandboxFactory and injection\n\n`harbor_env.py` defines a factory type and default:\n\n```python\nSandboxFactory = Callable[[modal.Image, int], Awaitable[SandboxInterface]]\n\nasync def default_sandbox_factory(image: modal.Image, timeout: int) -> SandboxInterface:\n    return await ModalSandbox.create(image=image, timeout=timeout)\n```\n\n`cli_main()` accepts an optional `sandbox_factory` parameter. When `None`, it falls back to `default_sandbox_factory` (Modal). 
The factory flows through: `cli_main` -> `HarborDatasetBuilder` -> `HarborEnvGroupBuilder.make_envs()`.\n\n## Running\n\nFirst, download the Terminal-Bench tasks:\n\n```bash\nuvx harbor datasets download terminal-bench@2.0\n```\n\nThen launch training:\n\n```bash\nuv run python tinker_cookbook/recipes/harbor_rl/scripts/train_terminal_bench.py\n```\n\n## Evaluation\n\nEvaluate a Tinker endpoint on Harbor datasets without training.\n\nDownload datasets:\n```bash\nuvx harbor datasets download terminal-bench@2.0 -o ~/.cache/harbor/tasks/terminal-bench-2.0\nuvx harbor datasets download swebench-verified@1.0 -o ~/.cache/harbor/tasks/swebench-verified-1.0\n```\n\nRun evaluation:\n```bash\nuv run python tinker_cookbook/recipes/harbor_rl/scripts/eval_terminal_bench.py\n```\n\nKey parameters in `EvalConfig`: `max_turns`, `max_tokens`, `temperature`.\n`run_eval()` also accepts `sandbox_factory` for custom sandbox backends and `output_path` to control where results are written (default: `tinker_cookbook/recipes/harbor_rl/scripts/results/<timestamp>/`).\n\nWe evaluated SWE-Bench-Verified-1.0 and Terminal-Bench-2.0 at 32K context length with a naive agent harness and no advanced features such as context compaction (summarizing the tool-calling history).\n\n### Results: Kimi-K2-Thinking (32K context, no compaction)\n\n| Benchmark | Total | PASS | FAIL | ERROR | Pass Rate |\n|-----------|-------|------|------|-------|-----------|\n| SWE-Bench Verified 1.0 | 500 | 46 (9.2%) | 52 (10.4%) | 402 (80.4%) | 9.2% |\n| Terminal-Bench 2.0 | 89 | 18 (20.2%) | 36 (40.4%) | 35 (39.3%) | 20.2% |\n\n**Config**: `max_turns=200, max_tokens=8192, temperature=0.1, sandbox_timeout=3600s`\n\nAll ERRORs are context-window overflows (`prompt_tokens + max_tokens > 32768`).\nThese occur when the conversation history exceeds ~24.5K tokens, leaving insufficient room for the 8192 `max_tokens` generation budget.\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/eval.py",
    "content": "\"\"\"\nStandalone evaluation for Harbor tasks.\n\nDownload harbor datasets:\n  uvx harbor datasets download swebench-verified@1.0 -o ~/.cache/harbor/tasks/swebench-verified-1.0\n  uvx harbor datasets download terminal-bench-2.0 -o ~/.cache/harbor/tasks/terminal-bench-2.0\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport json\nimport logging\nimport random\nimport time\nfrom dataclasses import dataclass\nfrom datetime import datetime\nfrom pathlib import Path\n\nimport chz\nimport modal\nimport tinker\n\nfrom tinker_cookbook import model_info, tokenizer_utils\nfrom tinker_cookbook.completers import TinkerTokenCompleter\nfrom tinker_cookbook.display import format_trajectory\nfrom tinker_cookbook.recipes.harbor_rl.harbor_env import (\n    HarborTask,\n    SandboxFactory,\n    _initial_messages,\n    default_sandbox_factory,\n)\nfrom tinker_cookbook.recipes.harbor_rl.harbor_tools import HarborBashTool, HarborReward\nfrom tinker_cookbook.renderers import get_renderer\nfrom tinker_cookbook.renderers.base import Renderer\nfrom tinker_cookbook.rl.rollouts import do_single_rollout\nfrom tinker_cookbook.tool_use import build_agent_tool_env\nfrom tinker_cookbook.utils.ml_log import dump_config\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass EvalConfig:\n    \"\"\"Configuration for Harbor evaluation.\"\"\"\n\n    model_name: str = \"moonshotai/Kimi-K2-Thinking\"\n    output_path: str = \"tinker_cookbook/recipes/harbor_rl/scripts/results\"\n    max_turns: int = 10\n    max_tokens: int = 2048\n    temperature: float = 0.0\n    sandbox_timeout: int = 3600\n    command_timeout: int = 120\n    grader_timeout: int = 60\n    max_tasks: int | None = None\n    checkpoint_url: str | None = None\n    base_url: str | None = None\n    renderer_name: str | None = None\n\n\n@dataclass\nclass TaskResult:\n    task_name: str\n    reward: float\n    reward_details: dict[str, float]\n    turns_used: int\n    time_seconds: float\n    error: str | None = None\n    trajectory_str: str | None = None\n\n\nasync def evaluate_task(\n    task: HarborTask,\n    policy: TinkerTokenCompleter,\n    renderer: Renderer,\n    sandbox_factory: SandboxFactory,\n    config: EvalConfig,\n    results_dir: Path,\n    lock: asyncio.Lock,\n    tokenizer: tokenizer_utils.Tokenizer | None = None,\n) -> TaskResult:\n    \"\"\"Evaluate a single task: create sandbox, run agent loop, grade, cleanup.\n\n    Writes results to files in results_dir as soon as the task completes.\n    \"\"\"\n    start = time.monotonic()\n    env_dir = task.task_dir / \"environment\"\n    dockerfile_path = env_dir / \"Dockerfile\"\n    image = modal.Image.from_dockerfile(path=str(dockerfile_path), context_dir=str(env_dir))\n\n    sandbox = await sandbox_factory(image, config.sandbox_timeout)\n    try:\n        bash_tool = HarborBashTool(sandbox, command_timeout=config.command_timeout)\n        reward_fn = HarborReward(\n            tests_dir=task.task_dir / \"tests\",\n            sandbox=sandbox,\n            grader_timeout=config.grader_timeout,\n        )\n\n        env = build_agent_tool_env(\n            renderer=renderer,\n            tools=[bash_tool.bash],\n            initial_messages=_initial_messages(task, renderer, bash_tool),\n            reward_fn=reward_fn,\n            max_turns=config.max_turns,\n        )\n\n        trajectory = await do_single_rollout(policy, env)\n        reward = sum(t.reward for t in trajectory.transitions)\n        reward_details = trajectory.transitions[-1].metrics if 
trajectory.transitions else {}\n        turns_used = len(trajectory.transitions)\n        elapsed = time.monotonic() - start\n\n        trajectory_str = (\n            format_trajectory(trajectory, tokenizer, only_last_transition=True)\n            if tokenizer\n            else None\n        )\n\n        result = TaskResult(\n            task_name=task.task_name,\n            reward=reward,\n            reward_details=reward_details,\n            turns_used=turns_used,\n            time_seconds=round(elapsed, 1),\n            trajectory_str=trajectory_str,\n        )\n    except Exception as e:\n        elapsed = time.monotonic() - start\n        logger.error(\"Task %s failed: %s\", task.task_name, e)\n        result = TaskResult(\n            task_name=task.task_name,\n            reward=0.0,\n            reward_details={},\n            turns_used=0,\n            time_seconds=round(elapsed, 1),\n            error=str(e),\n        )\n    finally:\n        try:\n            await sandbox.cleanup()\n        except Exception as e:\n            logger.warning(\"Sandbox cleanup failed for %s: %s\", task.task_name, e)\n\n    # Write results to files immediately\n    status = \"ERROR\" if result.error else (\"PASS\" if result.reward > 0 else \"FAIL\")\n    summary_line = (\n        f\"{result.task_name:<40} {result.reward:>7.1f} {result.turns_used:>6} \"\n        f\"{result.time_seconds:>8.1f} {status:>7}\\n\"\n    )\n\n    async with lock:\n        with open(results_dir / \"asummary.txt\", \"a\") as f:\n            f.write(summary_line)\n\n        if result.error:\n            with open(results_dir / \"aerr.txt\", \"a\") as f:\n                f.write(f\"{'=' * 60}\\n\")\n                f.write(f\"Task: {result.task_name}\\n\")\n                f.write(f\"{'=' * 60}\\n\")\n                f.write(f\"{result.error}\\n\\n\")\n\n        if result.trajectory_str:\n            (results_dir / f\"{result.task_name}.txt\").write_text(result.trajectory_str)\n\n    return result\n\n\nasync def run_eval(\n    config: EvalConfig,\n    tasks: list[HarborTask],\n    sandbox_factory: SandboxFactory = default_sandbox_factory,\n) -> list[TaskResult]:\n    \"\"\"Run evaluation on a list of Harbor tasks.\n\n    Results are written to files in <output_path>/<timestamp>/ as each task completes.\n\n    Args:\n        config: Evaluation configuration.\n        tasks: List of HarborTask to evaluate.\n        sandbox_factory: Factory for creating sandboxes (defaults to Modal).\n\n    Returns:\n        List of per-task results.\n    \"\"\"\n    results_dir = Path(config.output_path) / datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n    results_dir.mkdir(parents=True, exist_ok=True)\n    print(f\"Results dir: {results_dir}\")\n\n    config_dict = dump_config(config)\n    (results_dir / \"config.json\").write_text(json.dumps(config_dict, indent=2))\n\n    lock = asyncio.Lock()\n\n    service_client = tinker.ServiceClient(base_url=config.base_url)\n    if config.checkpoint_url:\n        sampling_client = service_client.create_sampling_client(\n            model_path=config.checkpoint_url,\n            base_model=config.model_name,\n        )\n    else:\n        sampling_client = service_client.create_sampling_client(base_model=config.model_name)\n\n    tokenizer = tokenizer_utils.get_tokenizer(config.model_name)\n    renderer_name = config.renderer_name or model_info.get_recommended_renderer_name(\n        config.model_name\n    )\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    policy = TinkerTokenCompleter(\n   
     sampling_client=sampling_client,\n        max_tokens=config.max_tokens,\n        temperature=config.temperature,\n    )\n\n    if config.max_tasks is not None:\n        tasks = random.sample(tasks, min(config.max_tasks, len(tasks)))\n\n    logger.info(\"Starting evaluation of %d tasks\", len(tasks))\n\n    task_results = list(\n        await asyncio.gather(\n            *[\n                evaluate_task(\n                    task,\n                    policy,\n                    renderer,\n                    sandbox_factory,\n                    config,\n                    results_dir,\n                    lock,\n                    tokenizer,\n                )\n                for task in tasks\n            ]\n        )\n    )\n\n    return task_results\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/harbor_env.py",
    "content": "\"\"\"Harbor environment, dataset, and dataset builder for RL training.\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport tomllib\nfrom collections.abc import Awaitable, Callable, Sequence\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any\n\nimport chz\nimport modal\n\nfrom tinker_cookbook import model_info, tokenizer_utils\nfrom tinker_cookbook.recipes.harbor_rl.harbor_tools import HarborBashTool, HarborReward\nfrom tinker_cookbook.renderers import get_renderer\nfrom tinker_cookbook.renderers.base import Message, Renderer\nfrom tinker_cookbook.rl.types import Env, EnvGroupBuilder, RLDataset, RLDatasetBuilder\nfrom tinker_cookbook.sandbox import SandboxInterface\nfrom tinker_cookbook.sandbox.modal_sandbox import ModalSandbox\nfrom tinker_cookbook.tool_use import build_agent_tool_env\nfrom tinker_cookbook.tool_use.agent_tool_message_env import RewardFn\n\nlogger = logging.getLogger(__name__)\n\nHARBOR_CACHE_DIR = Path.home() / \".cache\" / \"harbor\" / \"tasks\"\nHARBOR_SYSTEM_PROMPT = (\n    \"You are a skilled software engineer working in a sandboxed environment. \"\n    \"You have access to a bash tool to execute commands. \"\n    \"Complete the task described by the user.\"\n)\n\nSandboxFactory = Callable[[modal.Image, int], Awaitable[SandboxInterface]]\n\n\nasync def default_sandbox_factory(image: modal.Image, timeout: int) -> SandboxInterface:\n    return await ModalSandbox.create(image=image, timeout=timeout)\n\n\n@dataclass(frozen=True)\nclass HarborTask:\n    \"\"\"A single Harbor terminal-bench task.\"\"\"\n\n    task_name: str\n    instruction: str\n    task_dir: Path  # Convention: environment/Dockerfile, tests/test.sh\n    config: dict[str, Any] = field(default_factory=dict)\n\n\ndef load_harbor_tasks(dataset: str) -> list[HarborTask]:\n    \"\"\"Load Harbor tasks from ~/.cache/harbor/tasks/<dataset>/.\"\"\"\n    tasks_dir = HARBOR_CACHE_DIR / dataset\n    tasks: list[HarborTask] = []\n    for uuid_dir in sorted(tasks_dir.iterdir()):\n        (task_dir,) = [d for d in uuid_dir.iterdir() if d.is_dir()]\n        tasks.append(\n            HarborTask(\n                task_name=task_dir.name,\n                instruction=(task_dir / \"instruction.md\").read_text(),\n                task_dir=task_dir,\n                config=tomllib.loads((task_dir / \"task.toml\").read_text()),\n            )\n        )\n    tasks.sort(key=lambda t: t.task_name)\n    return tasks\n\n\ndef _initial_messages(\n    task: HarborTask,\n    renderer: Renderer,\n    bash_tool: HarborBashTool,\n) -> list[Message]:\n    \"\"\"Build initial messages with tool schemas and task instruction.\"\"\"\n    tool_schemas = [bash_tool.bash.to_spec()]\n    prefix = renderer.create_conversation_prefix_with_tools(\n        tools=tool_schemas,\n        system_prompt=HARBOR_SYSTEM_PROMPT,\n    )\n    return prefix + [{\"role\": \"user\", \"content\": task.instruction}]\n\n\nclass HarborEnvGroupBuilder(EnvGroupBuilder):\n    \"\"\"EnvGroupBuilder that creates Harbor environments with Modal sandboxes.\"\"\"\n\n    def __init__(\n        self,\n        task: HarborTask,\n        model_name: str,\n        renderer_name: str | None,\n        max_turns: int,\n        group_size: int,\n        sandbox_timeout: int = 600,\n        command_timeout: int = 120,\n        grader_timeout: int = 60,\n        max_trajectory_tokens: int = 32 * 1024,\n        sandbox_factory: SandboxFactory | None = None,\n        reward_fn: RewardFn | None = None,\n    ):\n      
  self.task = task\n        self.model_name = model_name\n        self.renderer_name = renderer_name\n        self.max_turns = max_turns\n        self.group_size = group_size\n        self.sandbox_timeout = sandbox_timeout\n        self.command_timeout = command_timeout\n        self.grader_timeout = grader_timeout\n        self.max_trajectory_tokens = max_trajectory_tokens\n        self.sandbox_factory = sandbox_factory or default_sandbox_factory\n        self.reward_fn = reward_fn\n        self._sandboxes: list[SandboxInterface] = []\n\n    async def make_envs(self) -> Sequence[Env]:\n        self._sandboxes = []\n\n        # Build Modal image from the task's Dockerfile\n        env_dir = self.task.task_dir / \"environment\"\n        dockerfile_path = env_dir / \"Dockerfile\"\n        image = modal.Image.from_dockerfile(path=str(dockerfile_path), context_dir=str(env_dir))\n\n        # Create renderer (stateless, shared across envs)\n        tokenizer = tokenizer_utils.get_tokenizer(self.model_name)\n        renderer_name = self.renderer_name or model_info.get_recommended_renderer_name(\n            self.model_name\n        )\n        renderer = get_renderer(renderer_name, tokenizer)\n\n        tests_dir = self.task.task_dir / \"tests\"\n\n        envs = []\n        for _ in range(self.group_size):\n            sandbox = await self.sandbox_factory(image, self.sandbox_timeout)\n            self._sandboxes.append(sandbox)\n\n            bash_tool = HarborBashTool(sandbox, command_timeout=self.command_timeout)\n            reward_fn = self.reward_fn or HarborReward(\n                tests_dir=tests_dir,\n                sandbox=sandbox,\n                grader_timeout=self.grader_timeout,\n            )\n            envs.append(\n                build_agent_tool_env(\n                    renderer=renderer,\n                    tools=[bash_tool.bash],\n                    initial_messages=_initial_messages(self.task, renderer, bash_tool),\n                    reward_fn=reward_fn,\n                    max_turns=self.max_turns,\n                    max_trajectory_tokens=self.max_trajectory_tokens,\n                )\n            )\n        return envs\n\n    async def cleanup(self) -> None:\n        for sandbox in self._sandboxes:\n            try:\n                await sandbox.cleanup()\n            except Exception as e:\n                logger.warning(\"Sandbox cleanup failed: %s\", e)\n        self._sandboxes.clear()\n\n    def logging_tags(self) -> list[str]:\n        return [\"harbor\"]\n\n\nclass HarborDataset(RLDataset):\n    \"\"\"Dataset that produces batches of HarborEnvGroupBuilders.\"\"\"\n\n    def __init__(\n        self,\n        env_group_builders: list[HarborEnvGroupBuilder],\n        batch_size: int,\n    ):\n        self.env_group_builders = env_group_builders\n        self.batch_size = batch_size\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        start = index * self.batch_size\n        end = start + self.batch_size\n        return self.env_group_builders[start:end]\n\n    def __len__(self) -> int:\n        return (len(self.env_group_builders) + self.batch_size - 1) // self.batch_size\n\n\n@chz.chz\nclass HarborDatasetBuilder(RLDatasetBuilder):\n    \"\"\"Build an RL dataset over Harbor tasks.\"\"\"\n\n    tasks: list[HarborTask]\n    batch_size: int\n    group_size: int\n    model_name: str\n    renderer_name: str | None = None\n    max_turns: int = 10\n    sandbox_timeout: int = 600\n    command_timeout: int = 120\n    grader_timeout: int = 
60\n    max_trajectory_tokens: int = 32 * 1024\n    sandbox_factory: SandboxFactory | None = None\n    reward_fn: RewardFn | None = None\n\n    def _make_env_group_builders(self, group_size: int) -> list[HarborEnvGroupBuilder]:\n        return [\n            HarborEnvGroupBuilder(\n                task=task,\n                model_name=self.model_name,\n                renderer_name=self.renderer_name,\n                max_turns=self.max_turns,\n                group_size=group_size,\n                sandbox_timeout=self.sandbox_timeout,\n                command_timeout=self.command_timeout,\n                grader_timeout=self.grader_timeout,\n                max_trajectory_tokens=self.max_trajectory_tokens,\n                sandbox_factory=self.sandbox_factory,\n                reward_fn=self.reward_fn,\n            )\n            for task in self.tasks\n        ]\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:\n        train_dataset = HarborDataset(\n            env_group_builders=self._make_env_group_builders(self.group_size),\n            batch_size=self.batch_size,\n        )\n        eval_dataset = HarborDataset(\n            env_group_builders=self._make_env_group_builders(group_size=1),\n            batch_size=self.batch_size,\n        )\n        return train_dataset, eval_dataset\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/harbor_tools.py",
    "content": "\"\"\"Harbor bash tool and reward function for RL training.\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Annotated\n\nfrom tinker_cookbook.renderers.base import Message\nfrom tinker_cookbook.sandbox import SandboxInterface\nfrom tinker_cookbook.tool_use import ToolResult, simple_tool_result, tool\n\nlogger = logging.getLogger(__name__)\n\nMAX_OUTPUT_CHARS = 16384\n\n\nclass HarborBashTool:\n    \"\"\"Bash tool that executes commands in a sandbox.\n\n    Wraps a SandboxInterface as a tinker_cookbook Tool via the @tool decorator.\n    \"\"\"\n\n    def __init__(self, sandbox: SandboxInterface, command_timeout: int = 120) -> None:\n        self._sandbox = sandbox\n        self._command_timeout = command_timeout\n\n    @tool\n    async def bash(\n        self,\n        command: Annotated[str, \"The bash command to execute.\"],\n    ) -> ToolResult:\n        \"\"\"Execute a bash command in the sandbox environment.\n\n        Use this to run shell commands, install packages, edit files, etc.\n        \"\"\"\n        result = await self._sandbox.run_command(\n            command, workdir=\"/\", timeout=self._command_timeout, max_output_bytes=MAX_OUTPUT_CHARS\n        )\n        stdout = result.stdout[:MAX_OUTPUT_CHARS]\n        stderr = result.stderr[:MAX_OUTPUT_CHARS]\n        output = json.dumps({\"exit_code\": result.exit_code, \"stdout\": stdout, \"stderr\": stderr})\n        return simple_tool_result(output)\n\n\n@dataclass\nclass HarborReward:\n    \"\"\"Reward function for Harbor tasks.\n\n    Grades by uploading test files to the sandbox, running test.sh,\n    and parsing reward from /logs/verifier/reward.txt or reward.json.\n\n    Called once at episode end with the full message history.\n    \"\"\"\n\n    tests_dir: Path\n    sandbox: SandboxInterface\n    grader_timeout: int = 60\n\n    async def __call__(self, history: list[Message]) -> tuple[float, dict[str, float]]:\n        \"\"\"Grade the completed episode by running test.sh in the sandbox.\"\"\"\n        try:\n            # 1. Upload test files to /tests/ in sandbox\n            await self._upload_tests()\n\n            # 2. Create log directory and run test.sh\n            # Run from /root (not /) because test.sh checks if PWD=/ and exits early\n            await self.sandbox.run_command(\"mkdir -p /logs/verifier\", workdir=\"/root\")\n            result = await self.sandbox.run_command(\n                \"bash /tests/test.sh\",\n                workdir=\"/root\",\n                timeout=self.grader_timeout,\n            )\n            logger.info(\"test.sh completed with exit_code=%d\", result.exit_code)\n            if result.stdout:\n                logger.debug(\"test.sh stdout: %s\", result.stdout[:500])\n            if result.stderr:\n                logger.debug(\"test.sh stderr: %s\", result.stderr[:500])\n\n            # 3. 
Parse reward\n            reward = await self._parse_reward()\n            return reward, {\"reward\": reward, \"test_passed\": float(reward > 0)}\n\n        except Exception as e:\n            logger.error(\"Harbor grading failed: %s\", e)\n            return 0.0, {\"reward\": 0.0, \"test_passed\": 0.0, \"grading_error\": 1.0}\n\n    async def _upload_tests(self) -> None:\n        \"\"\"Upload test files from local tests_dir to /tests/ in sandbox.\"\"\"\n        await self.sandbox.run_command(\"mkdir -p /tests\", workdir=\"/\")\n        for file_path in self.tests_dir.iterdir():\n            if not file_path.is_file():\n                continue\n            content = file_path.read_text()\n            target = f\"/tests/{file_path.name}\"\n            await self.sandbox.write_file(target, content, executable=(file_path.suffix == \".sh\"))\n\n    async def _parse_reward(self) -> float:\n        \"\"\"Parse reward from /logs/verifier/reward.txt or reward.json.\"\"\"\n        # Try reward.txt first\n        result = await self.sandbox.read_file(\"/logs/verifier/reward.txt\")\n        if result.exit_code == 0 and result.stdout.strip():\n            reward = float(result.stdout.strip())\n            logger.info(\"Parsed reward from reward.txt: %s\", reward)\n            return reward\n\n        # Try reward.json\n        result = await self.sandbox.read_file(\"/logs/verifier/reward.json\")\n        if result.exit_code == 0 and result.stdout.strip():\n            data = json.loads(result.stdout)\n            reward = float(data.get(\"reward\", 0.0))\n            logger.info(\"Parsed reward from reward.json: %s\", reward)\n            return reward\n\n        logger.warning(\"No reward file found at /logs/verifier/\")\n        return 0.0\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/harbor_tools_test.py",
    "content": "\"\"\"Unit tests for HarborReward, HarborBashTool, and HarborEnvGroupBuilder.\"\"\"\n\nimport asyncio\nimport json\nimport pickle\nfrom pathlib import Path\n\nfrom tinker_cookbook.recipes.harbor_rl.harbor_env import HarborEnvGroupBuilder, HarborTask\nfrom tinker_cookbook.recipes.harbor_rl.harbor_tools import (\n    MAX_OUTPUT_CHARS,\n    HarborBashTool,\n    HarborReward,\n)\nfrom tinker_cookbook.sandbox.sandbox_interface import SandboxResult\nfrom tinker_cookbook.tool_use.types import ToolInput\n\n\nclass FakeSandbox:\n    \"\"\"In-memory sandbox for testing.\"\"\"\n\n    def __init__(self) -> None:\n        self.files: dict[str, str] = {}\n        self.executable_files: set[str] = set()\n        self.commands_run: list[str] = []\n        self._command_results: dict[str, SandboxResult] = {}\n        self._default_result = SandboxResult(stdout=\"\", stderr=\"\", exit_code=0)\n\n    @property\n    def sandbox_id(self) -> str:\n        return \"fake-sandbox\"\n\n    def set_command_result(self, command: str, result: SandboxResult) -> None:\n        self._command_results[command] = result\n\n    async def send_heartbeat(self) -> None:\n        pass\n\n    async def run_command(\n        self,\n        command: str,\n        workdir: str | None = None,\n        timeout: int = 60,\n        max_output_bytes: int | None = None,\n    ) -> SandboxResult:\n        self.commands_run.append(command)\n        if command in self._command_results:\n            return self._command_results[command]\n        return self._default_result\n\n    async def read_file(\n        self, path: str, max_bytes: int | None = None, timeout: int = 60\n    ) -> SandboxResult:\n        if path in self.files:\n            return SandboxResult(stdout=self.files[path], stderr=\"\", exit_code=0)\n        return SandboxResult(stdout=\"\", stderr=f\"No such file: {path}\", exit_code=1)\n\n    async def write_file(\n        self, path: str, content: str | bytes, executable: bool = False, timeout: int = 60\n    ) -> SandboxResult:\n        self.files[path] = content if isinstance(content, str) else content.decode()\n        if executable:\n            self.executable_files.add(path)\n        return SandboxResult(stdout=\"\", stderr=\"\", exit_code=0)\n\n    async def cleanup(self) -> None:\n        pass\n\n\n# ---------------------------------------------------------------------------\n# HarborReward tests\n# ---------------------------------------------------------------------------\n\n\nclass TestHarborReward:\n    def _make_reward(self, tmp_path: Path, sandbox: FakeSandbox, **kwargs) -> HarborReward:\n        return HarborReward(tests_dir=tmp_path, sandbox=sandbox, **kwargs)\n\n    def test_reward_from_txt(self, tmp_path: Path) -> None:\n        sandbox = FakeSandbox()\n        sandbox.files[\"/logs/verifier/reward.txt\"] = \"1.0\"\n        reward_fn = self._make_reward(tmp_path, sandbox)\n\n        reward, info = asyncio.run(reward_fn([]))\n        assert reward == 1.0\n        assert info == {\"reward\": 1.0, \"test_passed\": 1.0}\n\n    def test_reward_from_json(self, tmp_path: Path) -> None:\n        sandbox = FakeSandbox()\n        sandbox.files[\"/logs/verifier/reward.json\"] = json.dumps({\"reward\": 0.75})\n        reward_fn = self._make_reward(tmp_path, sandbox)\n\n        reward, info = asyncio.run(reward_fn([]))\n        assert reward == 0.75\n        assert info == {\"reward\": 0.75, \"test_passed\": 1.0}\n\n    def test_no_reward_file(self, tmp_path: Path) -> None:\n        sandbox = FakeSandbox()\n 
       reward_fn = self._make_reward(tmp_path, sandbox)\n\n        reward, info = asyncio.run(reward_fn([]))\n        assert reward == 0.0\n        assert info == {\"reward\": 0.0, \"test_passed\": 0.0}\n\n    def test_zero_reward(self, tmp_path: Path) -> None:\n        sandbox = FakeSandbox()\n        sandbox.files[\"/logs/verifier/reward.txt\"] = \"0.0\"\n        reward_fn = self._make_reward(tmp_path, sandbox)\n\n        reward, info = asyncio.run(reward_fn([]))\n        assert reward == 0.0\n        assert info[\"test_passed\"] == 0.0\n\n    def test_grading_error(self, tmp_path: Path) -> None:\n        \"\"\"Sandbox exception during grading returns 0 reward with error flag.\"\"\"\n\n        class ExplodingSandbox(FakeSandbox):\n            async def run_command(\n                self,\n                command: str,\n                workdir: str | None = None,\n                timeout: int = 60,\n                max_output_bytes: int | None = None,\n            ) -> SandboxResult:\n                raise RuntimeError(\"sandbox died\")\n\n        sandbox = ExplodingSandbox()\n        reward_fn = self._make_reward(tmp_path, sandbox)\n\n        reward, info = asyncio.run(reward_fn([]))\n        assert reward == 0.0\n        assert info[\"grading_error\"] == 1.0\n\n    def test_upload_tests(self, tmp_path: Path) -> None:\n        \"\"\"Files are uploaded; .sh files marked executable; subdirs skipped.\"\"\"\n        (tmp_path / \"test.sh\").write_text(\"#!/bin/bash\\necho ok\")\n        (tmp_path / \"helper.py\").write_text(\"print('hi')\")\n        (tmp_path / \"subdir\").mkdir()\n\n        sandbox = FakeSandbox()\n        sandbox.files[\"/logs/verifier/reward.txt\"] = \"1.0\"\n        reward_fn = self._make_reward(tmp_path, sandbox)\n\n        asyncio.run(reward_fn([]))\n\n        assert \"/tests/test.sh\" in sandbox.files\n        assert \"/tests/helper.py\" in sandbox.files\n        assert \"/tests/test.sh\" in sandbox.executable_files\n        assert \"/tests/helper.py\" not in sandbox.executable_files\n        # subdir should not be uploaded\n        assert not any(\"subdir\" in p for p in sandbox.files if p.startswith(\"/tests/\"))\n\n\n# ---------------------------------------------------------------------------\n# HarborBashTool tests\n# ---------------------------------------------------------------------------\n\n\nclass TestHarborBashTool:\n    def test_bash_tool_basic(self) -> None:\n        sandbox = FakeSandbox()\n        sandbox.set_command_result(\n            \"echo hello\",\n            SandboxResult(stdout=\"hello\\n\", stderr=\"\", exit_code=0),\n        )\n        tool_obj = HarborBashTool(sandbox)\n\n        result = asyncio.run(tool_obj.bash.run(ToolInput(arguments={\"command\": \"echo hello\"})))\n        content = result.messages[0][\"content\"]\n        assert isinstance(content, str)\n        output = json.loads(content)\n        assert output[\"exit_code\"] == 0\n        assert output[\"stdout\"] == \"hello\\n\"\n        assert output[\"stderr\"] == \"\"\n\n    def test_bash_tool_truncation(self) -> None:\n        long_stdout = \"x\" * (MAX_OUTPUT_CHARS + 100)\n        long_stderr = \"e\" * (MAX_OUTPUT_CHARS + 100)\n        sandbox = FakeSandbox()\n        sandbox.set_command_result(\n            \"big_cmd\",\n            SandboxResult(stdout=long_stdout, stderr=long_stderr, exit_code=1),\n        )\n        tool_obj = HarborBashTool(sandbox)\n\n        result = asyncio.run(tool_obj.bash.run(ToolInput(arguments={\"command\": \"big_cmd\"})))\n        content = 
result.messages[0][\"content\"]\n        assert isinstance(content, str)\n        output = json.loads(content)\n        assert len(output[\"stdout\"]) == MAX_OUTPUT_CHARS\n        assert len(output[\"stderr\"]) == MAX_OUTPUT_CHARS\n        assert output[\"exit_code\"] == 1\n\n\n# ---------------------------------------------------------------------------\n# HarborEnvGroupBuilder pickle tests\n# ---------------------------------------------------------------------------\n\n\nclass TestHarborEnvGroupBuilderPickle:\n    def _make_task(self, tmp_path: Path) -> HarborTask:\n        return HarborTask(\n            task_name=\"test-task\",\n            instruction=\"Fix the bug\",\n            task_dir=tmp_path,\n            config={\"difficulty\": \"easy\"},\n        )\n\n    def test_pickle_roundtrip(self, tmp_path: Path) -> None:\n        \"\"\"HarborEnvGroupBuilder survives pickle/unpickle with default sandbox_factory.\"\"\"\n        builder = HarborEnvGroupBuilder(\n            task=self._make_task(tmp_path),\n            model_name=\"meta-llama/Llama-3.1-8B-Instruct\",\n            renderer_name=\"llama3\",\n            max_turns=5,\n            group_size=2,\n        )\n\n        restored = pickle.loads(pickle.dumps(builder))\n\n        assert restored.task == builder.task\n        assert restored.model_name == builder.model_name\n        assert restored.renderer_name == builder.renderer_name\n        assert restored.max_turns == builder.max_turns\n        assert restored.group_size == builder.group_size\n        assert restored.sandbox_timeout == builder.sandbox_timeout\n        assert restored.command_timeout == builder.command_timeout\n        assert restored.grader_timeout == builder.grader_timeout\n        assert restored.max_trajectory_tokens == builder.max_trajectory_tokens\n        assert restored.sandbox_factory is builder.sandbox_factory\n\n    def test_pickle_with_custom_params(self, tmp_path: Path) -> None:\n        \"\"\"Non-default scalar parameters survive pickle roundtrip.\"\"\"\n        builder = HarborEnvGroupBuilder(\n            task=self._make_task(tmp_path),\n            model_name=\"meta-llama/Llama-3.1-8B-Instruct\",\n            renderer_name=\"llama3\",\n            max_turns=5,\n            group_size=2,\n            sandbox_timeout=300,\n            command_timeout=60,\n            grader_timeout=30,\n            max_trajectory_tokens=16 * 1024,\n        )\n\n        restored = pickle.loads(pickle.dumps(builder))\n\n        assert restored.sandbox_timeout == 300\n        assert restored.command_timeout == 60\n        assert restored.grader_timeout == 30\n        assert restored.max_trajectory_tokens == 16 * 1024\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/scripts/.gitignore",
    "content": "results/\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/scripts/eval_terminal_bench.py",
    "content": "\"\"\"\nLoad Terminal-Bench tasks from the Harbor cache and run evaluation.\n\nuv run python tinker_cookbook/recipes/harbor_rl/scripts/eval_terminal_bench.py\n\n\"\"\"\n\nimport asyncio\n\nfrom tinker_cookbook.recipes.harbor_rl.eval import EvalConfig, run_eval\nfrom tinker_cookbook.recipes.harbor_rl.harbor_env import default_sandbox_factory, load_harbor_tasks\n\nif __name__ == \"__main__\":\n    config = EvalConfig(\n        max_turns=200,\n        temperature=0.1,\n        max_tokens=8192,\n    )\n    tasks = load_harbor_tasks(\"terminal-bench-2.0\")\n    asyncio.run(run_eval(config, tasks, sandbox_factory=default_sandbox_factory))\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/scripts/train_terminal_bench.py",
    "content": "\"\"\"\nLoad Terminal-Bench tasks from the Harbor cache and launch RL training.\n\nuv run python tinker_cookbook/recipes/harbor_rl/scripts/train_terminal_bench.py\n\n\"\"\"\n\nimport asyncio\n\nfrom tinker_cookbook.recipes.harbor_rl.harbor_env import default_sandbox_factory, load_harbor_tasks\nfrom tinker_cookbook.recipes.harbor_rl.train import CLIConfig, cli_main\n\nif __name__ == \"__main__\":\n    cli_config = CLIConfig()\n    tasks = load_harbor_tasks(\"terminal-bench-2.0\")\n    asyncio.run(cli_main(cli_config, tasks, sandbox_factory=default_sandbox_factory))\n"
  },
  {
    "path": "tinker_cookbook/recipes/harbor_rl/train.py",
    "content": "\"\"\"CLI entry point for Harbor RL training.\"\"\"\n\nimport logging\nfrom datetime import datetime\n\nimport chz\n\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.recipes.harbor_rl.harbor_env import (\n    HarborDatasetBuilder,\n    HarborTask,\n    SandboxFactory,\n)\nfrom tinker_cookbook.rl.train import AsyncConfig, Config, main\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Command-line configuration for Harbor RL training.\"\"\"\n\n    # Model configuration\n    model_name: str = \"moonshotai/Kimi-K2-Thinking\"\n    lora_rank: int = 32\n    renderer_name: str | None = None\n    load_checkpoint_path: str | None = None\n    max_tokens: int = 8192\n    temperature: float = 1.0\n\n    # Environment configuration\n    max_turns: int = 10\n    sandbox_timeout: int = 3600\n    command_timeout: int = 120\n    grader_timeout: int = 60\n\n    # Training hyperparameters\n    group_size: int = 4\n    groups_per_batch: int = 8\n    learning_rate: float = 1e-5\n    kl_penalty_coef: float = 0.0\n    num_substeps: int = 1\n\n    # Logging / eval / checkpoints\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    eval_every: int = 5\n    save_every: int = 5\n\n    # Service configuration\n    base_url: str | None = None\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    # Async rollout configuration\n    max_steps_off_policy: int | None = None\n\n    max_steps: int | None = None\n\n\nasync def cli_main(\n    cli_config: CLIConfig,\n    tasks: list[HarborTask],\n    sandbox_factory: SandboxFactory | None = None,\n) -> None:\n    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(\n        cli_config.model_name\n    )\n\n    model_tag = cli_config.model_name.replace(\"/\", \"-\")\n    run_name = (\n        f\"harbor-{model_tag}-{cli_config.lora_rank}rank-\"\n        f\"{cli_config.learning_rate}lr-{cli_config.group_size}group-\"\n        f\"{cli_config.groups_per_batch}batch-\"\n        f\"{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n    )\n\n    log_path = cli_config.log_path or f\"/tmp/tinker-examples/harbor_rl/{run_name}\"\n    wandb_name = cli_config.wandb_name or run_name\n\n    dataset_builder = HarborDatasetBuilder(\n        tasks=tasks,\n        batch_size=cli_config.groups_per_batch,\n        group_size=cli_config.group_size,\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        max_turns=cli_config.max_turns,\n        sandbox_timeout=cli_config.sandbox_timeout,\n        command_timeout=cli_config.command_timeout,\n        grader_timeout=cli_config.grader_timeout,\n        sandbox_factory=sandbox_factory,\n    )\n\n    config = Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_builder=dataset_builder,\n        model_name=cli_config.model_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        temperature=cli_config.temperature,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        log_path=log_path,\n        base_url=cli_config.base_url,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        num_substeps=cli_config.num_substeps,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        async_config=AsyncConfig(\n            
max_steps_off_policy=cli_config.max_steps_off_policy,\n            groups_per_batch=cli_config.groups_per_batch,\n        )\n        if cli_config.max_steps_off_policy is not None\n        else None,\n        max_steps=cli_config.max_steps,\n    )\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    await main(config)\n"
  },
  {
    "path": "tinker_cookbook/recipes/math_rl/README.md",
    "content": "# Using Reinforcement Learning to Solve Math Problems\n\nMath problems have been the most active testbed for RL with LLMs. This recipe collects environments and grading functions that allow you to test on several popular math datasets.\n\n## Installation\n\n```bash\nuv pip install 'tinker-cookbook[math-rl] @ git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly'\n```\n\n## RL on arithmetic.\n\nTrivial, but runs fast enough that you can see it learn. Reward should go from 0.66 to 1 in the first few steps.\n\n```bash\npython -m tinker_cookbook.recipes.math_rl.train model_name=\"meta-llama/Llama-3.2-1B\" group_size=4 groups_per_batch=100 learning_rate=1e-4\n```\n\n## RL on MATH dataset.\n\n```bash\npython -m tinker_cookbook.recipes.math_rl.train env=math model_name=\"Qwen/Qwen3-8B\" group_size=16 groups_per_batch=64 learning_rate=2e-5 max_tokens=512\n```\n\nAfter 180 steps of training, we observe `\"test/env/all/correct\": 0.767578125`, which is logged to `/tmp/tinker-examples/math_rl/math-Qwen_Qwen3-8B-32rank-2e-05lr-${DATE}/metrics.jsonl`.\n\n```\n<|im_start|>user\nThe numbers 2, 3, 5, 7, 11, 13 are arranged in a multiplication table, with three along the top and the other three down the left.  The multiplication table is completed a\nnd the sum of the nine entries is tabulated.  What is the largest possible sum of the nine entries?\n\n\\[\n\\begin{array}{c||c|c|c|}\n\\times & a & b & c \\\\ \\hline \\hline\nd & & & \\\\ \\hline\ne & & & \\\\ \\hline\nf & & & \\\\ \\hline\n\\end{array}\n\\] Write your answer in \\boxed{} format.<|im_end|>\n<|im_start|>assistant\n<think>\nOkay, so we have to arrange the primes 2,3,5,7,11,13 into a 3x3 multiplication table. The sum of all entries is (a+b+c)*(d+e+f) - because the sum of the products is the pr\noduct of the sums. To maximize the total sum, we need to maximize the product of the sums of the two groups. So, split the numbers into two groups of three, one group for\nthe top row (a,b,c) and one for the left column (d,e,f). The product (sum1)*(sum2) is maximized when the sums are as close as possible? Or maybe when one is as large as po\nssible and the other as small? Wait, for a fixed sum, the product is maximized when they're equal. But since the total sum of all numbers is 2+3+5+7+11+13= 41. So sum1 + s\num2 =41. To maximize sum1*sum2, we need sum1 and sum2 as close as possible to 41/2=20.5. So find two subsets of three numbers that add up to around 20.5.\n\nLet me try combinations. Let's see:\n\nPossible sums for three numbers:\n\n- 2+3+13=18\n\n- 2+5+13=20\n- 2+7+11=20\n\n- 3+5+13=21\n\n- 3+7+11=21\n\n- 5+7+11=23\n\nSo the best possible is sum1=20 and sum2=21, product 420. So the maximum sum is 420. \\boxed{420}</think>\n\n\\boxed{420}<|im_end|>\n```\n\n# RL on GSM8K\n\n```bash\npython -m tinker_cookbook.recipes.math_rl.train env=gsm8k model_name=\"meta-llama/Llama-3.1-8B-Instruct\" group_size=64 groups_per_batch=32 learning_rate=8e-5 max_tokens=1024\n```\n\nGenerally, you should observe that training reward goes above 0.8 within a few steps. After 220 steps of training, we achieve `\"test/env/all/correct\": 0.9090`. A smaller group_size (8) and larger `groups_per_batch` (64) will achieve `0.8824` accuracy in around a quarter of the time.\n\nHere is an example chain of thought,\n```\nA plane takes off at 6:00 a.m. and flies for 4 hours from New York City to Chicago. The plane stays at the port in Chicago for 1 hour and then departs for Miami. 
\n## RL on GSM8K\n\n```bash\npython -m tinker_cookbook.recipes.math_rl.train env=gsm8k model_name=\"meta-llama/Llama-3.1-8B-Instruct\" group_size=64 groups_per_batch=32 learning_rate=8e-5 max_tokens=1024\n```\n\nGenerally, you should observe that training reward goes above 0.8 within a few steps. After 220 steps of training, we achieve `\"test/env/all/correct\": 0.9090`. A smaller `group_size` (8) and larger `groups_per_batch` (64) will achieve `0.8824` accuracy in around a quarter of the time.\n\nHere is an example chain of thought:\n```\nA plane takes off at 6:00 a.m. and flies for 4 hours from New York City to Chicago. The plane stays at the port in Chicago for 1 hour and then departs for Miami. If the aircraft took three times as many hours to fly to Miami than it took to fly from New York to Chicago, calculate the total time to travel from New York to Miami. Write your answer in \\boxed{} format.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\nLet's break it down step by step:\n1. The plane flies from New York City to Chicago for 4 hours. This duration is given.\n2. The plane stays at the port in Chicago for 1 hour.\n3. The time it takes to fly from Chicago to Miami is three times the time it took to fly from New York to Chicago, which is 4 hours. So, the time to fly from Chicago to Miami is 3 * 4 = 12 hours.\nNow, let's calculate the total time:\n* Flight from New York City to Chicago: 4 hours\n* Stay at the port in Chicago: 1 hour\n* Flight from Chicago to Miami: 12 hours\nTotal time = 4 + 1 + 12 = 17 hours\nSo, the total time to travel from New York to Miami is 17 hours.\n\\boxed{17}<|eot_id|>\n```\n"
  },
  {
    "path": "tinker_cookbook/recipes/math_rl/arithmetic_env.py",
    "content": "from collections.abc import Sequence\nfrom functools import partial\n\nimport chz\nimport numpy as np\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.rl.problem_env import ProblemEnv, ProblemGroupBuilder\nfrom tinker_cookbook.rl.types import EnvGroupBuilder, RLDataset, RLDatasetBuilder\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\nclass ArithmeticEnv(ProblemEnv):\n    \"\"\"\n    A toy environment for solving addition problems.\n    \"\"\"\n\n    def __init__(\n        self,\n        x: int,\n        y: int,\n        renderer: renderers.Renderer,\n        convo_prefix: list[renderers.Message] | None = None,\n    ):\n        super().__init__(renderer, convo_prefix)\n        self.x = x\n        self.y = y\n\n    def get_question(self) -> str:\n        return f\"What is {self.x} + {self.y}?\"\n\n    def check_answer(self, sample_str: str) -> bool:\n        chunks = sample_str.split()\n        try:\n            answer = int(chunks[0])\n        except (ValueError, IndexError):\n            return False\n        return answer == self.x + self.y\n\n    def check_format(self, sample_str: str) -> bool:\n        return True\n\n    def get_reference_answer(self) -> str:\n        return str(self.x + self.y)\n\n    @staticmethod\n    def standard_fewshot_prefix() -> list[renderers.Message]:\n        return [\n            {\"role\": \"user\", \"content\": \"What is 4 + 5?\"},\n            {\"role\": \"assistant\", \"content\": \"9\"},\n        ]\n\n\nclass ArithmeticDataset(RLDataset):\n    def __init__(\n        self,\n        batch_size: int,\n        renderer: renderers.Renderer,\n        group_size: int,\n        n_batches: int = 100,\n        include_fewshot: bool = True,\n    ):\n        self._rng = np.random.RandomState(None)\n        self.batch_size = batch_size\n        self.group_size = group_size\n        self.renderer = renderer\n        self.n_batches = n_batches\n        self.include_fewshot = include_fewshot\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        self._rng.seed(index)\n        return [self._make_env_group_builder(self._rng) for _ in range(self.batch_size)]\n\n    def _make_env_group_builder(self, rng: np.random.RandomState) -> ProblemGroupBuilder:\n        x = rng.randint(0, 101)\n        y = rng.randint(0, 101)\n        convo_prefix = ArithmeticEnv.standard_fewshot_prefix() if self.include_fewshot else None\n        return ProblemGroupBuilder(\n            env_thunk=partial(\n                ArithmeticEnv, x, y, convo_prefix=convo_prefix, renderer=self.renderer\n            ),\n            num_envs=self.group_size,\n        )\n\n    def __len__(self) -> int:\n        return self.n_batches\n\n\n@chz.chz\nclass ArithmeticDatasetBuilder(RLDatasetBuilder):\n    batch_size: int\n    model_name_for_tokenizer: str\n    renderer_name: str\n    n_batches: int\n    group_size: int\n    include_fewshot: bool = True\n\n    async def __call__(self) -> tuple[ArithmeticDataset, None]:\n        tokenizer = get_tokenizer(self.model_name_for_tokenizer)\n        return ArithmeticDataset(\n            batch_size=self.batch_size,\n            renderer=renderers.get_renderer(self.renderer_name, tokenizer=tokenizer),\n            n_batches=self.n_batches,\n            include_fewshot=self.include_fewshot,\n            group_size=self.group_size,\n        ), None\n"
  },
  {
    "path": "tinker_cookbook/recipes/math_rl/math_env.py",
    "content": "import math\nimport re\nfrom collections.abc import Sequence\nfrom functools import partial\nfrom typing import Literal, cast\n\nimport chz\nfrom datasets import Dataset, concatenate_datasets, get_dataset_config_names, load_dataset\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.recipes.math_rl.math_grading import (\n    extract_boxed,\n    grade_answer,\n    grade_answer_math_verify,\n    run_with_timeout_signal,\n)\nfrom tinker_cookbook.rl.problem_env import ProblemEnv, ProblemGroupBuilder, logger\nfrom tinker_cookbook.rl.types import EnvGroupBuilder, RLDataset, RLDatasetBuilder\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\nclass MathEnv(ProblemEnv):\n    def __init__(\n        self,\n        problem: str,\n        answer: str,\n        renderer: renderers.Renderer,\n        convo_prefix: list[renderers.Message] | None = None,\n        grader: Literal[\"sympy\", \"math_verify\"] = \"sympy\",\n        timeout: float = 1.0,\n    ):\n        super().__init__(renderer, convo_prefix)\n        self.problem = problem\n        self.answer = answer\n        self.grader = grader\n        self.timeout = timeout\n\n    @classmethod\n    def question_suffix(cls) -> str:\n        return \" Write your answer in \\\\boxed{} format.\"\n\n    def get_question(self) -> str:\n        return self.problem + self.question_suffix()\n\n    def check_format(self, sample_str: str) -> bool:\n        try:\n            _ = extract_boxed(sample_str)\n            return True\n        except ValueError:\n            return False\n\n    def check_answer(self, sample_str: str) -> bool:\n        try:\n            answer = extract_boxed(sample_str)\n        except ValueError:\n            return False\n        return safe_grade(answer, self.answer, self.grader, self.timeout)\n\n    def get_reference_answer(self) -> str:\n        return self.answer\n\n    @staticmethod\n    def standard_fewshot_prefix() -> list[renderers.Message]:\n        return [\n            {\n                \"role\": \"user\",\n                \"content\": \"How many r's are in strawberry?\" + MathEnv.question_suffix(),\n            },\n            {\n                \"role\": \"assistant\",\n                \"content\": \"Let's spell the word out and number all the letters: 1) s 2) t 3) r 4) a 5) w 6) b 7) e 8) r 9) r 10) y. We have r's at positions 3, 8, and 9. \\\\boxed{3}\",\n            },\n        ]\n\n\ndef safe_grade(given_answer: str, ground_truth: str, grader: str = \"sympy\", timeout: float = 1.0):\n    if grader == \"sympy\":\n        grader_func = grade_answer\n    elif grader == \"math_verify\":\n        grader_func = grade_answer_math_verify\n    else:\n        raise ValueError(f\"Invalid grader: {grader}\")\n    out = run_with_timeout_signal(\n        grader_func, args=(given_answer, ground_truth), timeout_seconds=int(math.ceil(timeout))\n    )\n    if out is None:\n        logger.warning(f\"Timeout grading {given_answer} against {ground_truth}\")\n        return False\n    return out\n\n\ndef extract_gsm8k_final_answer(text: str) -> str:\n    \"\"\"Extract the final numeric/string answer from a GSM8K solution field.\n\n    GSM8K format typically places the final answer on a line starting with\n    '####'. 
We take the substring following '####' on the last such line.\n    \"\"\"\n    lines = text.splitlines()\n    for line in reversed(lines):\n        s = line.strip()\n        if s.startswith(\"####\"):\n            content = s[4:].strip()\n            if content.startswith(\":\"):\n                content = content[1:].strip()\n            content = content.replace(\",\", \"\").strip()\n            return content\n    matches = re.findall(r\"####\\s*(.+)\", text)\n    if matches:\n        return matches[-1].strip()\n    raise ValueError(\"No GSM8K final answer found\")\n\n\ndef _get_hendrycks_math_test() -> Dataset:\n    test_dataset = load_dataset(\"HuggingFaceH4/MATH-500\", name=\"default\", split=\"test\")\n    return cast(Dataset, test_dataset)\n\n\ndef _get_hendrycks_math_train() -> Dataset:\n    # For Hendrycks MATH, the standard is to use both the \"train\" and \"test\" splits for\n    # training. The \"test\" split here is NOT the same as the MATH-500 test split above,\n    # which is a commonly-held-out subset of 500 of the below 12.5k problems. To construct\n    # a clean training set, we filter out problems that exist in the MATH-500 test set,\n    # resulting in 12000 train and 500 test problems.\n\n    test_problems: set[str] = {\n        problem[\"problem\"]  # pyright: ignore[reportArgumentType, reportCallIssue]\n        for problem in _get_hendrycks_math_test()\n    }\n\n    dataset_name = \"EleutherAI/hendrycks_math\"\n    configs = get_dataset_config_names(dataset_name)\n    pieces = []\n    for cfg in configs:\n        for split in (\"train\", \"test\"):\n            ds = load_dataset(dataset_name, name=cfg, split=split)\n            ds = ds.filter(lambda example: example[\"problem\"] not in test_problems)\n            pieces.append(ds)\n    full_dataset = concatenate_datasets(pieces)\n\n    return full_dataset\n\n\nclass MathDataset(RLDataset):\n    def __init__(\n        self,\n        batch_size: int,\n        group_size: int,\n        renderer: renderers.Renderer,\n        convo_prefix: list[renderers.Message] | None = None,\n        split: Literal[\"train\", \"test\"] = \"train\",\n        seed: int = 0,\n    ):\n        if split == \"train\":\n            self.ds = _get_hendrycks_math_train().shuffle(seed=seed)\n        elif split == \"test\":\n            self.ds = _get_hendrycks_math_test()\n        self.batch_size = batch_size\n        self.group_size = group_size if split == \"train\" else 1\n        self.renderer = renderer\n        self.convo_prefix = convo_prefix\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        batch_start = index * self.batch_size\n        batch_end = min((index + 1) * self.batch_size, len(self.ds))\n        assert batch_start < batch_end, \"Incorrect batch size\"\n        return [\n            builder\n            for row in self.ds.select(range(batch_start, batch_end))\n            if (builder := self._make_env_group_builder(row, self.group_size)) is not None  # pyright: ignore[reportArgumentType]\n        ]\n\n    def __len__(self) -> int:\n        return math.ceil(len(self.ds) / self.batch_size)\n\n    def _make_env_group_builder(\n        self, x: dict[str, str], group_size: int\n    ) -> ProblemGroupBuilder | None:\n        try:\n            answer = extract_boxed(x[\"solution\"])\n        except ValueError:  # not sure if this happens\n            logger.warning(f\"No answer found for {x['solution']}\")\n            return None\n        return ProblemGroupBuilder(\n            env_thunk=partial(\n          
      MathEnv, x[\"problem\"], answer, self.renderer, convo_prefix=self.convo_prefix\n            ),\n            num_envs=group_size,\n        )\n\n\n@chz.chz\nclass MathDatasetBuilder(RLDatasetBuilder):\n    batch_size: int\n    model_name_for_tokenizer: str\n    renderer_name: str\n    group_size: int\n    convo_prefix: list[renderers.Message] | None | Literal[\"standard\"] = \"standard\"\n    seed: int = 0\n\n    async def __call__(self) -> tuple[MathDataset, MathDataset]:\n        if self.convo_prefix == \"standard\":\n            convo_prefix = MathEnv.standard_fewshot_prefix()\n        else:\n            convo_prefix = self.convo_prefix\n        tokenizer = get_tokenizer(self.model_name_for_tokenizer)\n        renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)\n        datasets = [\n            MathDataset(\n                batch_size=self.batch_size,\n                group_size=self.group_size,\n                renderer=renderer,\n                convo_prefix=convo_prefix,\n                split=split,\n                seed=self.seed,\n            )\n            for split in (\"train\", \"test\")\n        ]\n        return (datasets[0], datasets[1])\n\n\nclass PolarisDataset(MathDataset):\n    def __init__(\n        self,\n        batch_size: int,\n        group_size: int,\n        renderer: renderers.Renderer,\n        convo_prefix: list[renderers.Message] | None = None,\n        seed: int = 0,\n    ):\n        # Don't call super().__init__ since we're overriding the dataset loading\n        self.ds = load_dataset(\"POLARIS-Project/Polaris-Dataset-53K\", split=\"train\").shuffle(\n            seed=seed\n        )\n        self.batch_size = batch_size\n        self.group_size = group_size\n        self.renderer = renderer\n        self.convo_prefix = convo_prefix\n\n    def _make_env_group_builder(\n        self, x: dict[str, str], group_size: int\n    ) -> ProblemGroupBuilder | None:\n        # Extract problem and answer from the dataset\n        problem = x.get(\"problem\", \"\")\n        answer = x.get(\"answer\", \"\")\n        if not (problem and answer):\n            return None\n        return ProblemGroupBuilder(\n            env_thunk=partial(\n                MathEnv, problem, answer, self.renderer, convo_prefix=self.convo_prefix\n            ),\n            num_envs=group_size,\n            dataset_name=\"polaris\",\n        )\n\n\n@chz.chz\nclass PolarisDatasetBuilder(RLDatasetBuilder):\n    batch_size: int\n    model_name_for_tokenizer: str\n    renderer_name: str\n    group_size: int\n    seed: int = 0\n\n    async def __call__(self) -> tuple[PolarisDataset, None]:\n        tokenizer = get_tokenizer(self.model_name_for_tokenizer)\n        return PolarisDataset(\n            batch_size=self.batch_size,\n            group_size=self.group_size,\n            renderer=renderers.get_renderer(self.renderer_name, tokenizer=tokenizer),\n            seed=self.seed,\n        ), None\n\n\nclass DeepMathDataset(MathDataset):\n    def __init__(\n        self,\n        batch_size: int,\n        group_size: int,\n        renderer: renderers.Renderer,\n        convo_prefix: list[renderers.Message] | None = None,\n        seed: int = 0,\n    ):\n        # Don't call super().__init__ since we're overriding the dataset loading\n        self.ds = load_dataset(\"zwhe99/DeepMath-103K\", split=\"train\").shuffle(seed=seed)\n        self.batch_size = batch_size\n        self.group_size = group_size\n        self.renderer = renderer\n        self.convo_prefix = 
convo_prefix\n\n    def _make_env_group_builder(\n        self, x: dict[str, str], group_size: int\n    ) -> ProblemGroupBuilder | None:\n        # Extract problem and answer from the dataset\n        problem = x.get(\"question\", \"\")\n        answer = x.get(\"final_answer\", \"\")\n        if not (problem and answer):\n            return None\n        return ProblemGroupBuilder(\n            env_thunk=partial(\n                MathEnv, problem, answer, self.renderer, convo_prefix=self.convo_prefix\n            ),\n            num_envs=group_size,\n            dataset_name=\"deepmath\",\n        )\n\n\n@chz.chz\nclass DeepMathDatasetBuilder(RLDatasetBuilder):\n    batch_size: int\n    model_name_for_tokenizer: str\n    renderer_name: str\n    group_size: int\n    seed: int = 0\n\n    async def __call__(self) -> tuple[DeepMathDataset, None]:\n        tokenizer = get_tokenizer(self.model_name_for_tokenizer)\n        return DeepMathDataset(\n            batch_size=self.batch_size,\n            group_size=self.group_size,\n            renderer=renderers.get_renderer(self.renderer_name, tokenizer=tokenizer),\n            seed=self.seed,\n        ), None\n\n\nclass Gsm8kDataset(RLDataset):\n    def __init__(\n        self,\n        batch_size: int,\n        group_size: int,\n        renderer: renderers.Renderer,\n        convo_prefix: list[renderers.Message] | None = None,\n        split: Literal[\"train\", \"test\"] = \"train\",\n        seed: int = 0,\n    ):\n        if split not in (\"train\", \"test\"):\n            raise ValueError(\"split must be 'train' or 'test'\")\n        self.ds = cast(Dataset, load_dataset(\"openai/gsm8k\", name=\"main\", split=split))\n        if split == \"train\":\n            self.ds = self.ds.shuffle(seed=seed)\n        self.batch_size = batch_size\n        self.group_size = group_size if split == \"train\" else 1\n        self.renderer = renderer\n        self.convo_prefix = convo_prefix\n\n    @classmethod\n    def question_suffix(cls) -> str:\n        return \" Provide a numerical answer without units, written inside \\\\boxed{}.\"\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        batch_start = index * self.batch_size\n        batch_end = min((index + 1) * self.batch_size, len(self.ds))\n        assert batch_start < batch_end, \"Incorrect batch size\"\n        return [\n            builder\n            for row in self.ds.select(range(batch_start, batch_end))\n            if (builder := self._make_env_group_builder(row, self.group_size)) is not None  # pyright: ignore[reportArgumentType]\n        ]\n\n    def __len__(self) -> int:\n        return math.ceil(len(self.ds) / self.batch_size)\n\n    def _make_env_group_builder(\n        self, x: dict[str, str], group_size: int\n    ) -> ProblemGroupBuilder | None:\n        try:\n            problem = x[\"question\"]\n            answer = extract_gsm8k_final_answer(x[\"answer\"])\n        except Exception as e:\n            logger.warning(f\"Failed to parse GSM8K row: {e}\")\n            return None\n        return ProblemGroupBuilder(\n            env_thunk=partial(\n                MathEnv, problem, answer, self.renderer, convo_prefix=self.convo_prefix\n            ),\n            num_envs=group_size,\n        )\n\n\n@chz.chz\nclass Gsm8kDatasetBuilder(RLDatasetBuilder):\n    batch_size: int\n    model_name_for_tokenizer: str\n    renderer_name: str\n    group_size: int\n    convo_prefix: list[renderers.Message] | None | Literal[\"standard\"] = \"standard\"\n    seed: int = 0\n\n    
async def __call__(self) -> tuple[Gsm8kDataset, Gsm8kDataset]:\n        if self.convo_prefix == \"standard\":\n            convo_prefix = MathEnv.standard_fewshot_prefix()\n        else:\n            convo_prefix = self.convo_prefix\n        tokenizer = get_tokenizer(self.model_name_for_tokenizer)\n        renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)\n        datasets = [\n            Gsm8kDataset(\n                batch_size=self.batch_size,\n                group_size=self.group_size,\n                renderer=renderer,\n                convo_prefix=convo_prefix,\n                split=split,\n                seed=self.seed,\n            )\n            for split in (\"train\", \"test\")\n        ]\n        return (datasets[0], datasets[1])\n\n\n# Populate the dataset builder map after all classes are defined\nDATASET_BUILDER_MAP = {\n    \"math\": MathDatasetBuilder,\n    \"polaris\": PolarisDatasetBuilder,\n    \"deepmath\": DeepMathDatasetBuilder,\n    \"gsm8k\": Gsm8kDatasetBuilder,\n}\n\n\ndef get_math_dataset_builder(\n    dataset_name: str,\n    batch_size: int,\n    model_name_for_tokenizer: str,\n    renderer_name: str,\n    group_size: int,\n    seed: int = 0,\n) -> RLDatasetBuilder:\n    \"\"\"\n    Unified function to get any math dataset builder.\n    Args:\n        dataset_name: One of \"math\", \"polaris\", \"deepmath\", or \"gsm8k\"\n        batch_size: Number of groups per batch\n        model_name_for_tokenizer: Model name for tokenizer\n        renderer_name: Name of the renderer to use\n        group_size: Number of environments per group\n        seed: Random seed for data shuffling (default: 0)\n    Returns:\n        The appropriate dataset builder instance\n    \"\"\"\n    if dataset_name not in DATASET_BUILDER_MAP:\n        raise ValueError(\n            f\"Unknown math dataset: {dataset_name}. Available: {list(DATASET_BUILDER_MAP.keys())}\"\n        )\n\n    builder_class = DATASET_BUILDER_MAP[dataset_name]\n\n    return builder_class(\n        batch_size=batch_size,\n        model_name_for_tokenizer=model_name_for_tokenizer,\n        renderer_name=renderer_name,\n        group_size=group_size,\n        seed=seed,\n    )\n"
  },
  {
    "path": "tinker_cookbook/recipes/math_rl/math_env_test.py",
    "content": "import asyncio\n\nfrom tinker_cookbook.recipes.math_rl.math_env import MathDatasetBuilder\n\n\ndef test_math_dataset_builder():\n    builder = MathDatasetBuilder(\n        batch_size=1,\n        model_name_for_tokenizer=\"Qwen/Qwen3-4B-Instruct-2507\",\n        renderer_name=\"qwen3_instruct\",\n        group_size=1,\n    )\n    train_dataset, test_dataset = asyncio.run(builder())\n\n    # Basic dataset statistics\n    assert len(train_dataset) == 12_000\n    assert len(test_dataset) == 500\n    assert len(train_dataset.get_batch(0)) == 1\n    assert len(test_dataset.get_batch(0)) == 1\n\n    # Check for contamination of train and test sets\n    test_questions = set()\n    for i in range(len(test_dataset)):\n        batch = test_dataset.get_batch(index=i)\n        test_questions.add(batch[0].env_thunk().get_question())  # pyright: ignore\n    for i in range(len(train_dataset)):\n        batch = train_dataset.get_batch(index=i)\n        assert batch[0].env_thunk().get_question() not in test_questions  # pyright: ignore\n\n\nif __name__ == \"__main__\":\n    test_math_dataset_builder()\n"
  },
  {
    "path": "tinker_cookbook/recipes/math_rl/math_grading.py",
    "content": "\"\"\"\nMath grading utilities for RL training.\n\nIncludes math_normalize functionality that was dependency of grader.\n\"\"\"\n\nimport contextlib\nimport logging\nimport re\nfrom collections.abc import Callable\nfrom concurrent.futures import ThreadPoolExecutor\nfrom concurrent.futures import TimeoutError as FuturesTimeoutError\nfrom typing import Any, TypeVar\n\ntry:\n    import sympy\n    from pylatexenc import latex2text\n    from sympy.parsing import sympy_parser\nexcept ImportError:\n    raise ImportError(\n        \"math-rl dependencies (sympy, pylatexenc, math-verify) are required for this recipe. \"\n        \"Install them with: uv pip install 'tinker-cookbook[math-rl] @ \"\n        \"git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly'\"\n    ) from None\n\nlogger = logging.getLogger(__name__)\n\n\nT = TypeVar(\"T\")\n\n# ======================================================================\n# Math Normalize Functions\n# ======================================================================\n\n\ndef normalize_answer(answer: str | None) -> str | None:\n    if answer is None:\n        return None\n    answer = answer.strip()\n    try:\n        # Remove enclosing `\\text{}`.\n        m = re.search(\"^\\\\\\\\text\\\\{(?P<text>.+?)\\\\}$\", answer)\n        if m is not None:\n            answer = m.group(\"text\").strip()\n        return _strip_string(str(answer))\n    except Exception:\n        return answer\n\n\ndef _fix_fracs(string: str) -> str:\n    substrs = string.split(\"\\\\frac\")\n    new_str = substrs[0]\n    if len(substrs) > 1:\n        substrs = substrs[1:]\n        for substr in substrs:\n            new_str += \"\\\\frac\"\n            if substr[0] == \"{\":\n                new_str += substr\n            else:\n                try:\n                    assert len(substr) >= 2\n                except AssertionError:\n                    return string\n                a = substr[0]\n                b = substr[1]\n                if b != \"{\":\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}{\" + b + \"}\" + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}{\" + b + \"}\"\n                else:\n                    if len(substr) > 2:\n                        post_substr = substr[2:]\n                        new_str += \"{\" + a + \"}\" + b + post_substr\n                    else:\n                        new_str += \"{\" + a + \"}\" + b\n    string = new_str\n    return string\n\n\ndef _fix_a_slash_b(string: str) -> str:\n    if len(string.split(\"/\")) != 2:\n        return string\n    a = string.split(\"/\")[0]\n    b = string.split(\"/\")[1]\n    try:\n        a = int(a)\n        b = int(b)\n        assert string == f\"{a}/{b}\"\n        new_string = \"\\\\frac{\" + str(a) + \"}{\" + str(b) + \"}\"\n        return new_string\n    except ValueError:\n        return string\n\n\ndef _remove_right_units(string: str) -> str:\n    # \"\\\\text{ \" only ever occurs (at least in the val set) when describing units\n    if \"\\\\text{ \" in string:\n        splits = string.split(\"\\\\text{ \")\n        assert len(splits) == 2\n        return splits[0]\n    else:\n        return string\n\n\ndef _fix_sqrt(string: str) -> str:\n    if \"\\\\sqrt\" not in string:\n        return string\n    splits = string.split(\"\\\\sqrt\")\n    new_string = splits[0]\n    for split in splits[1:]:\n        if split[0] 
!= \"{\":\n            a = split[0]\n            new_substr = \"\\\\sqrt{\" + a + \"}\" + split[1:]\n        else:\n            new_substr = \"\\\\sqrt\" + split\n        new_string += new_substr\n    return new_string\n\n\ndef _strip_string(string: str) -> str:\n    # linebreaks\n    string = string.replace(\"\\n\", \"\")\n\n    # remove inverse spaces\n    string = string.replace(\"\\\\!\", \"\")\n\n    # replace \\\\ with \\\n    string = string.replace(\"\\\\\\\\\", \"\\\\\")\n\n    # replace tfrac and dfrac with frac\n    string = string.replace(\"tfrac\", \"frac\")\n    string = string.replace(\"dfrac\", \"frac\")\n\n    # remove \\left and \\right\n    string = string.replace(\"\\\\left\", \"\")\n    string = string.replace(\"\\\\right\", \"\")\n\n    # Remove circ (degrees)\n    string = string.replace(\"^{\\\\circ}\", \"\")\n    string = string.replace(\"^\\\\circ\", \"\")\n\n    # remove dollar signs\n    string = string.replace(\"\\\\$\", \"\")\n\n    # remove units (on the right)\n    string = _remove_right_units(string)\n\n    # remove percentage\n    string = string.replace(\"\\\\%\", \"\")\n    string = string.replace(r\"\\%\", \"\")\n\n    # \" 0.\" equivalent to \" .\" and \"{0.\" equivalent to \"{.\" Alternatively, add \"0\" if \".\" is the start of the string\n    string = string.replace(\" .\", \" 0.\")\n    string = string.replace(\"{.\", \"{0.\")\n    # if empty, return empty string\n    if len(string) == 0:\n        return string\n    if string[0] == \".\":\n        string = \"0\" + string\n\n    # to consider: get rid of e.g. \"k = \" or \"q = \" at beginning\n    if len(string.split(\"=\")) == 2 and len(string.split(\"=\")[0]) <= 2:\n        string = string.split(\"=\")[1]\n\n    # fix sqrt3 --> sqrt{3}\n    string = _fix_sqrt(string)\n\n    # remove spaces\n    string = string.replace(\" \", \"\")\n\n    # \\frac1b or \\frac12 --> \\frac{1}{b} and \\frac{1}{2}, etc. Even works with \\frac1{72} (but not \\frac{72}1). 
Also does a/b --> \\\\frac{a}{b}\n    string = _fix_fracs(string)\n\n    # manually change 0.5 --> \\frac{1}{2}\n    if string == \"0.5\":\n        string = \"\\\\frac{1}{2}\"\n\n    # NOTE: X/Y changed to \\frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y\n    string = _fix_a_slash_b(string)\n\n    return string\n\n\n# ======================================================================\n# Extract Boxed Functions\n# ======================================================================\n\n\ndef extract_boxed(text: str) -> str:\n    \"\"\"\n    Extract the context of the last \\\\boxed{...} in the text.\n    Used for getting answers from hendrycks math\n    \"\"\"\n    boxed_strs = []\n    stack = []\n    for ichar in range(len(text)):\n        if text[ichar] == \"{\":\n            stack.append(ichar)\n        elif text[ichar] == \"}\":\n            if len(stack) == 0:\n                raise ValueError(\"Unmatched }\")\n            last_open_start = stack.pop()\n            # check if start is preceded by \\boxed\n            if text[:last_open_start].endswith(\"\\\\boxed\"):\n                boxed_strs.append(text[last_open_start + 1 : ichar])\n    if len(boxed_strs) > 0:\n        return boxed_strs[-1]\n    else:\n        # maybe there's something like '\\boxed 2' without curly braces\n        match = re.search(r\"\\\\boxed\\s+([a-zA-Z0-9]+)\", text)\n        if match:\n            return match.group(1)\n        else:\n            raise ValueError(\"No boxed strings found\")\n\n\n# ======================================================================\n# Grader Functions\n# ======================================================================\n\n# sympy might hang -- we don't care about trying to be lenient in these cases\nBAD_SUBSTRINGS = [\"^{\", \"^(\"]\nBAD_REGEXES = [r\"\\^[0-9]+\\^\", r\"\\^[0-9][0-9]+\"]\nTUPLE_CHARS = \"()[]\"\n\n\ndef _sympy_parse(expr: str):\n    \"\"\"Parses an expression with sympy.\"\"\"\n    py_expr = expr.replace(\"^\", \"**\")\n    return sympy_parser.parse_expr(\n        py_expr,\n        transformations=(\n            sympy_parser.standard_transformations\n            + (sympy_parser.implicit_multiplication_application,)\n        ),\n    )\n\n\ndef _parse_latex(expr: str) -> str:\n    \"\"\"Attempts to parse latex to an expression sympy can read.\"\"\"\n    expr = expr.replace(\"\\\\tfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\dfrac\", \"\\\\frac\")\n    expr = expr.replace(\"\\\\frac\", \" \\\\frac\")  # Play nice with mixed numbers.\n    expr = latex2text.LatexNodes2Text().latex_to_text(expr)\n\n    # Replace the specific characters that this parser uses.\n    expr = expr.replace(\"√\", \"sqrt\")\n    expr = expr.replace(\"π\", \"pi\")\n    expr = expr.replace(\"∞\", \"inf\")\n    expr = expr.replace(\"∪\", \"U\")\n    expr = expr.replace(\"·\", \"*\")\n    expr = expr.replace(\"×\", \"*\")\n\n    return expr.strip()\n\n\ndef _is_float(num: str) -> bool:\n    try:\n        float(num)\n        return True\n    except (ValueError, OverflowError):\n        return False\n\n\ndef _is_int(x: float) -> bool:\n    try:\n        return abs(x - int(round(x))) <= 1e-7\n    except (ValueError, OverflowError):\n        return False\n\n\ndef _is_frac(expr: str) -> bool:\n    return bool(re.search(r\"^-?[0-9]+.?/0*[1-9][0-9]*.?$\", expr))\n\n\ndef _str_is_int(x_str: str) -> bool:\n    try:\n        x_str = _strip_properly_formatted_commas(x_str)\n        x = float(x_str)\n        return abs(x - int(round(x))) <= 1e-7\n    
except (ValueError, OverflowError):\n        return False\n\n\ndef _str_to_int(x_str: str) -> int:\n    x_str = x_str.replace(\",\", \"\")\n    x = float(x_str)\n    return int(x)\n\n\ndef _inject_implicit_mixed_number(step: str):\n    \"\"\"\n    Automatically make a mixed number evalable\n    e.g. 7 3/4 => 7+3/4\n    \"\"\"\n    p1 = re.compile(\"([0-9]) +([0-9])\")\n    step = p1.sub(\"\\\\1+\\\\2\", step)  ## implicit mults\n    return step\n\n\ndef _strip_properly_formatted_commas(expr: str):\n    # We want to be careful because we don't want to strip tuple commas\n    p1 = re.compile(r\"(\\d)(,)(\\d\\d\\d)($|\\D)\")\n    while True:\n        next_expr = p1.sub(\"\\\\1\\\\3\\\\4\", expr)\n        if next_expr == expr:\n            break\n        expr = next_expr\n    return next_expr\n\n\ndef _normalize(expr: str) -> str:\n    \"\"\"Normalize answer expressions.\"\"\"\n    if expr is None:\n        return None\n\n    # Remove enclosing `\\text{}`.\n    m = re.search(\"^\\\\\\\\text\\\\{(?P<text>.+?)\\\\}$\", expr)\n    if m is not None:\n        expr = m.group(\"text\")\n\n    expr = expr.replace(\"\\\\%\", \"%\")\n    expr = expr.replace(\"\\\\$\", \"$\")\n    expr = expr.replace(\"$\", \"\")\n    expr = expr.replace(\"%\", \"\")\n    expr = expr.replace(\" or \", \" , \")\n    expr = expr.replace(\" and \", \" , \")\n\n    expr = expr.replace(\"million\", \"*10^6\")\n    expr = expr.replace(\"billion\", \"*10^9\")\n    expr = expr.replace(\"trillion\", \"*10^12\")\n\n    for unit in [\n        \"degree\",\n        \"cm\",\n        \"centimeter\",\n        \"meter\",\n        \"mile\",\n        \"second\",\n        \"minute\",\n        \"hour\",\n        \"day\",\n        \"week\",\n        \"month\",\n        \"year\",\n        \"foot\",\n        \"feet\",\n        \"inch\",\n        \"yard\",\n    ]:\n        expr = re.sub(rf\"{unit}(es)?(s)? *(\\^[0-9]+)?\", \"\", expr)\n    expr = re.sub(\"\\\\^ *\\\\\\\\circ\", \"\", expr)\n\n    if len(expr) > 0 and expr[0] == \"{\" and expr[-1] == \"}\":\n        expr = expr[1:-1]\n\n    expr = re.sub(\",\\\\\\\\! 
*\", \"\", expr)\n    if _is_float(expr) and _is_int(float(expr)):\n        expr = str(int(round(float(expr))))\n    if \"\\\\\" in expr:\n        with contextlib.suppress(Exception):\n            expr = _parse_latex(expr)\n\n    # edge case with mixed numbers and negative signs\n    expr = re.sub(\"- *\", \"-\", expr)\n\n    expr = _inject_implicit_mixed_number(expr)\n    expr = expr.replace(\" \", \"\")\n\n    # if we somehow still have latex braces here, just drop them\n    expr = expr.replace(\"{\", \"\")\n    expr = expr.replace(\"}\", \"\")\n\n    # don't be case sensitive for text answers\n    expr = expr.lower()\n\n    if _str_is_int(expr):\n        expr = str(_str_to_int(expr))\n\n    return expr\n\n\ndef count_unknown_letters_in_expr(expr: str):\n    expr = expr.replace(\"sqrt\", \"\")\n    expr = expr.replace(\"frac\", \"\")\n    letters_in_expr = {x for x in expr if x.isalpha()}\n    return len(letters_in_expr)\n\n\ndef should_allow_eval(expr: str):\n    # we don't want to try parsing unknown text or functions of more than two variables\n    if count_unknown_letters_in_expr(expr) > 2:\n        return False\n\n    for bad_string in BAD_SUBSTRINGS:\n        if bad_string in expr:\n            return False\n\n    return all(re.search(bad_regex, expr) is not None for bad_regex in BAD_REGEXES)\n\n\ndef are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):\n    are_equal = False\n    try:\n        expr = f\"({ground_truth_normalized})-({given_normalized})\"\n        if should_allow_eval(expr):\n            sympy_diff = _sympy_parse(expr)\n            simplified = sympy.simplify(sympy_diff)\n            if simplified == 0:\n                are_equal = True\n    except Exception:\n        pass\n    return are_equal\n\n\ndef split_tuple(expr: str):\n    \"\"\"\n    Split the elements in a tuple/interval, while handling well-formatted commas in large numbers\n    \"\"\"\n    expr = _strip_properly_formatted_commas(expr)\n    if len(expr) == 0:\n        return []\n    if (\n        len(expr) > 2\n        and expr[0] in TUPLE_CHARS\n        and expr[-1] in TUPLE_CHARS\n        and all(ch not in expr[1:-1] for ch in TUPLE_CHARS)\n    ):\n        elems = [elem.strip() for elem in expr[1:-1].split(\",\")]\n    else:\n        elems = [expr]\n    return elems\n\n\ndef grade_answer(given_answer: str, ground_truth: str) -> bool:\n    \"\"\"\n    The answer will be considered correct if:\n    (a) it normalizes to the same string as the ground truth answer\n    OR\n    (b) sympy can simplify the difference between the expressions to 0\n    \"\"\"\n    if given_answer is None:\n        return False\n\n    ground_truth_normalized_mathd = normalize_answer(ground_truth)\n    given_answer_normalized_mathd = normalize_answer(given_answer)\n\n    # be at least as lenient as mathd\n    if ground_truth_normalized_mathd == given_answer_normalized_mathd:\n        return True\n\n    ground_truth_normalized = _normalize(ground_truth)\n    given_normalized = _normalize(given_answer)\n\n    if ground_truth_normalized is None:\n        return False\n\n    if ground_truth_normalized == given_normalized:\n        return True\n\n    if len(given_normalized) == 0:\n        return False\n\n    ground_truth_elems = split_tuple(ground_truth_normalized)\n    given_elems = split_tuple(given_normalized)\n\n    is_correct = False\n\n    if (\n        len(ground_truth_elems) > 1\n        and (\n            ground_truth_normalized[0] != given_normalized[0]\n            or ground_truth_normalized[-1] != 
given_normalized[-1]\n        )\n    ) or len(ground_truth_elems) != len(given_elems):\n        is_correct = False\n    else:\n        for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True):\n            if _is_frac(ground_truth_elem) and _is_frac(given_elem):\n                # if fractions aren't reduced, then shouldn't be marked as correct\n                # so, we don't want to allow sympy.simplify in this case\n                is_correct = ground_truth_elem == given_elem\n            elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):\n                # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)\n                is_correct = False\n            else:\n                is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)\n            if not is_correct:\n                break\n\n    return is_correct\n\n\ndef grade_answer_math_verify(given_answer: str, ground_truth: str) -> bool:\n    \"\"\"\n    Use the math_verify package to verify the answer.\n    \"\"\"\n    from math_verify import parse, verify\n\n    #   Make sure the answer is wrapped in $ if it already isn't otherwise it is not parsed correctly\n    if not given_answer.startswith(\"$\") and not given_answer.endswith(\"$\"):\n        given_answer = f\"${given_answer}$\"\n    if not ground_truth.startswith(\"$\") and not ground_truth.endswith(\"$\"):\n        ground_truth = f\"${ground_truth}$\"\n\n    given_answer_parsed = parse(given_answer)\n    ground_truth_parsed = parse(ground_truth)\n\n    is_correct = verify(given_answer_parsed, ground_truth_parsed)\n\n    return is_correct\n\n\n# ======================================================================\n# Timeout Functions\n# ======================================================================\n\n\n# Define a custom exception for timeouts\nclass TimeoutException(Exception):\n    pass\n\n\ndef run_with_timeout_signal(\n    func: Callable[..., T],\n    args: tuple[Any, ...] = (),\n    kwargs: dict[str, Any] | None = None,\n    timeout_seconds: int = 5,\n) -> T | None:\n    \"\"\"\n    Runs a function with a timeout using ThreadPoolExecutor (cross-platform).\n\n    Args:\n        func: The function to execute.\n        args: Positional arguments for the function.\n        kwargs: Keyword arguments for the function.\n        timeout_seconds: Maximum time allowed in seconds.\n\n    Returns:\n        The result of the function call, or None if it times out.\n    \"\"\"\n    if kwargs is None:\n        kwargs = {}\n    with ThreadPoolExecutor(max_workers=1) as executor:\n        future = executor.submit(func, *args, **kwargs)\n        try:\n            result = future.result(timeout=timeout_seconds)\n        except FuturesTimeoutError:\n            logger.warning(f\"Function timed out after {timeout_seconds} seconds.\")\n            result = None\n        except Exception as e:\n            # Handle other exceptions from the function if needed\n            logger.warning(f\"Function raised an exception: {e}\")\n            result = None  # Or re-raise\n\n    return result\n"
  },
  {
    "path": "tinker_cookbook/recipes/math_rl/train.py",
    "content": "import asyncio\nimport logging\nfrom datetime import datetime\nfrom typing import Any\n\nimport chz\nfrom tinker.types import LossFnType\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.recipes.math_rl import (\n    arithmetic_env,\n    math_env,\n)\nfrom tinker_cookbook.rl.train import AsyncConfig, Config, StreamMinibatchConfig, main\nfrom tinker_cookbook.rl.types import RLDatasetBuilder\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Simple command-line configuration for RL training.\"\"\"\n\n    # Model configuration\n    model_name: str = \"meta-llama/Llama-3.1-8B-Instruct\"\n    lora_rank: int = 32\n    renderer_name: str | None = None\n    load_checkpoint_path: str | None = None\n\n    # Environment configuration\n    env: str = \"arithmetic\"  # Options: arithmetic, math, polaris, deepmath, gsm8k\n    seed: int = 0  # Random seed for data shuffling\n\n    # Training hyperparameters\n    group_size: int = 4\n    groups_per_batch: int = 100\n    learning_rate: float = 1e-5\n    max_tokens: int = 5\n    temperature: float = 1.0\n    kl_penalty_coef: float = 0.0\n\n    # Number of optimizer steps per training iteration.\n    # Useful for very large batch sizes.\n    num_substeps: int = 1\n\n    # Logging configuration\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    compute_post_kl: bool = False\n\n    # Evals\n    eval_every: int = 20\n\n    # Checkpointing\n    save_every: int = 20\n\n    # Service configuration\n    base_url: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps_off_policy: int | None = None\n\n    # Stream minibatch: train on minibatches as soon as they are ready\n    # instead of waiting for the full batch.\n    stream_minibatch_config: StreamMinibatchConfig | None = None\n\n    # Loss function and configuration.\n    # See https://tinker-docs.thinkingmachines.ai/losses\n    loss_fn: LossFnType = \"importance_sampling\"\n    loss_fn_config: dict[str, Any] | None = None\n\n    max_steps: int | None = None\n\n\ndef get_dataset_builder(\n    env: str,\n    batch_size: int,\n    model_name: str,\n    renderer_name: str,\n    group_size: int,\n    seed: int = 0,\n) -> RLDatasetBuilder:\n    if env == \"arithmetic\":\n        return arithmetic_env.ArithmeticDatasetBuilder(\n            batch_size=batch_size,\n            model_name_for_tokenizer=model_name,\n            renderer_name=renderer_name,\n            n_batches=100,\n            include_fewshot=True,\n            group_size=group_size,\n        )\n    elif env in [\"math\", \"polaris\", \"deepmath\", \"gsm8k\"]:\n        return math_env.get_math_dataset_builder(\n            dataset_name=env,\n            batch_size=batch_size,\n            model_name_for_tokenizer=model_name,\n            renderer_name=renderer_name,\n            group_size=group_size,\n            seed=seed,\n        )\n    else:\n        raise ValueError(f\"Unknown environment: {env}\")\n\n\nasync def cli_main(cli_config: CLIConfig):\n    \"\"\"Convert CLI config to full config and run training.\"\"\"\n\n    # Get tokenizer for stop sequences\n    renderer_name = await checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n  
  model_name = cli_config.model_name.replace(\"/\", \"-\")\n    run_name = f\"{cli_config.env}-{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.group_size}group-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n    # create log path if it doesn't exist\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/math_rl/{run_name}\"\n\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n    # Create full config\n    config = Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_builder=get_dataset_builder(\n            env=cli_config.env,\n            batch_size=cli_config.groups_per_batch,\n            model_name=cli_config.model_name,\n            renderer_name=renderer_name,\n            group_size=cli_config.group_size,\n            seed=cli_config.seed,\n        ),\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        temperature=cli_config.temperature,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        log_path=log_path,\n        base_url=cli_config.base_url,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        compute_post_kl=cli_config.compute_post_kl,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        num_substeps=cli_config.num_substeps,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        async_config=AsyncConfig(\n            max_steps_off_policy=cli_config.max_steps_off_policy,\n            groups_per_batch=cli_config.groups_per_batch,\n        )\n        if cli_config.max_steps_off_policy is not None\n        else None,\n        stream_minibatch_config=StreamMinibatchConfig(\n            groups_per_batch=cli_config.groups_per_batch,\n            num_minibatches=cli_config.stream_minibatch_config.num_minibatches,\n        )\n        if cli_config.stream_minibatch_config is not None\n        else None,\n        loss_fn=cli_config.loss_fn,\n        loss_fn_config=cli_config.loss_fn_config,\n        max_steps=cli_config.max_steps,\n    )\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    # Run training\n    await main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/README.md",
    "content": "# Multiturn Training\n\nOften we not only want large language models (LLMs) to generate a single response, but also to perform well across multiple turns of interaction.\nTo help Tinker users easily customize their own training, we provide the *Environment* abstraction.\n\nWe cover three examples, with increasing complexity.\n1. [Guess the number](./guess_number/): where the policy learns to guess the target number with multiple tries, given feedback on whether the number is too high or low.\n2. [Twenty Questions](./twenty_questions): where the policy learns to guess an underlying object by asking yes/no questions.\n3. [Tic-Tac-Toe](./text_arena): where the policy learns by playing against itself.\n\nThe first example is the simplest, since the user turn can be programmed with simple python statements.\nThe second is more complicated, since we need a language model to answer yes/no questions to the policy.\nThe third one is the most complicated, since we need to train multiple LLMs at the same time.\nFortunately, our `Environment` abstraction can handle all of them, and we will show examples of how to implement each one.\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/guess_number/README.md",
    "content": "# A Simple `Environment` for Guessing the Number\n\n```bash\npython -m tinker_cookbook.recipes.multiplayer_rl.guess_number.train\n```\n\nThe `test/env/all/reward/total` should increase from ~40% to >=50% in 20 steps.\n\n### Background: Guess the Number\n\nIn this task, we train an LLM to guess a hidden integer number between 0 and 1024.\n- If the LLM guess correctly, the task is done\n- If the LLM guesses incorrectly, the LLM will be given the chance to guess again and information whether the guessed number is too low or high.\n- The interaction automatically ends after 10 guesses.\nThe LLM is rewarded with 1 if it guesses correctly, otherwise 0.\n\nHere is one example game if the correct guess is 640:\n```\n[User]: Guess a number between 0 and 1024.\n[LLM]: Guess: 512\n[User]: Too low.\n[LLM]: Guess: 768\n[User]: Too high.\n[LLM]: Guess: 640\n[User]: Correct!\n```\n\n### Defining a `Guess-the-Number` Environment in Reinforcement Learning (RL)\n\nIn RL [1] (or more accurately, POMDP [2]), we need to specify the following components to define an environment:\n\n| Component | Description |\n|-----------|-------------|\n| 1. **Action Space** | The tokens generated by the LLM |\n| 2. **Observation Space** | The tokens that the LLM sees |\n| 3. **Initial Observation** | The initial tokens that the LLM sees |\n| 4. **Transition Function** | How the LLM-generated tokens determine what the user would say (e.g. Correct, Too high/low) |\n| 5. **Reward Function** | Whether the LLM has output the correct guess |\n\nThe action space and observation space are the same for most LLM applications.\n\n### Implementing the `Environment` Object\n\nTo customize your own training environment, you need to write a file like `recipes.multiplayer_rl.guess_number.env`.\nWe have already implemented the abstract class `Env` for you, so you can focus on implementing the initial observation, transition function, and reward function.\n\nWe will start by explaining the `step` function, which is the \"core\" of the environment: it determines\n- how the LLM-generated action will influence the environment,\n- what is the reward of this action,\n- whether the game will end with this action, and\n- what is the LLM's next observation.\n\nIn our implementation, we first parse the integer tokens into messages, and access the LLM generation in Python string format with `message[\"content\"]`.\nThen we use a helper function `_get_user_turn` to compute what the user would say, and what the reward should be. This helper function is straightforward to implement, because it is plain Python without using special libraries.\n\n(What's the difference between a tokenizer and a renderer? 
\n(What's the difference between a tokenizer and a renderer? A renderer performs one more step: it first converts messages (a list of dictionaries) into a raw string, which is then converted to a list of integers by a tokenizer.)\n\nTo obtain the next observation that the LLM will see, we first update the conversation history `self.turns` by appending the message the policy just generated and the resulting user turn.\nThen `self._obs` returns the LLM's next observation, which is the \"tokenized conversation history\" provided by the renderer.\nWe end the implementation by returning a `StepResult`, which contains the next observation, the reward, whether the game should stop, and `self.stop_condition` (usually the renderer's stop sequences).\n\nBased on the knowledge above, it is straightforward to implement `initial_observation`.\nWe have now finished implementing the customized training environment for guessing numbers.\n\n### Constructing the Environment Object\n\nThe RL training config takes `GuessNumberDatasetBuilder` as an argument,\n\n* which constructs `GuessNumberDataset`,\n* which constructs `GuessNumberEnvGroupBuilder`,\n* which constructs a group of `GuessNumberEnv`.\n\nWhile there are a lot of classes involved, most of them are just boilerplate. Crucial information, such as the answer and the renderer, is constructed in `GuessNumberDatasetBuilder.__call__` and eventually passed down to construct `GuessNumberEnv`.\n\n### Next\n\nWe have covered a basic example of guessing numbers. For more sophisticated examples that involve more than one language model, feel free to read the README.md files under `recipes.multiplayer_rl.twenty_questions` and `recipes.multiplayer_rl.text_arena`.\n\n### References\n\n[1] Sutton, R. S., & Barto, A. G. (2018). Reinforcement learning: An introduction (2nd ed.). MIT Press. http://incompleteideas.net/book/the-book-2nd.html\n[2] Kaelbling, L. P., Littman, M. L., & Cassandra, A. R. (1998). Planning and acting in partially observable stochastic domains. Artificial Intelligence, 101(1–2), 99–134. https://doi.org/10.1016/S0004-3702(98)00023-X\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/guess_number/env.py",
    "content": "import random\nimport re\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\n\nimport chz\nfrom tinker import ModelInput\n\nfrom tinker_cookbook.completers import (\n    StopCondition,\n)\nfrom tinker_cookbook.renderers import Message, Renderer, ensure_text, get_renderer\nfrom tinker_cookbook.rl.types import (\n    Action,\n    Env,\n    EnvGroupBuilder,\n    RLDataset,\n    RLDatasetBuilder,\n    StepResult,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n_UPPER_BOUND = 1024\nSYSTEM_PROMPT = f\"\"\"\nYour job is to guess a integer between 0 and {_UPPER_BOUND}. You can output your guess with the format Guess: <guess> (without the angle brackets), and say nothing else. Then we will tell you whether you guess is too high, too low, or correct.\"\"\".strip()\nFAIL_TO_PARSE = \"Failed to parse. Please output in the format Guess: <guess> (without the angle brackets), and say nothing else.\"\nFORMAT_PENALTY = -1.0\nMAX_TURNS = 10\nRETURN_ON_FAIL = (Message(role=\"user\", content=FAIL_TO_PARSE), FORMAT_PENALTY)\n\n\nclass GuessNumberEnv(Env):\n    def __init__(self, gold_answer: int, renderer: Renderer):\n        self.system_prompt: Message = {\"role\": \"system\", \"content\": SYSTEM_PROMPT}\n        self.renderer: Renderer = renderer\n        self.turns: list[Message] = []\n        self.gold_answer: int = gold_answer\n\n    @property\n    def stop_condition(self) -> StopCondition:\n        return self.renderer.get_stop_sequences()\n\n    @property\n    def _obs(self) -> ModelInput:\n        \"\"\"Get the observation for the player in tokenized form\"\"\"\n        convo_for_player = [self.system_prompt] + self.turns\n        return self.renderer.build_generation_prompt(convo_for_player)\n\n    async def initial_observation(self) -> tuple[ModelInput, StopCondition]:\n        return self._obs, self.stop_condition\n\n    def _get_user_turn(self, action_text: str) -> tuple[Message, float]:\n        # parse the answer from the action text (i.e. LLM's guess)\n        match = re.match(r\"Guess: (.*)\", action_text)\n        maybe_answer = match.group(1) if match else None\n        try:\n            if maybe_answer is not None:\n                int_answer = int(maybe_answer)\n                if int_answer == self.gold_answer:\n                    text, reward = \"Correct\", 1.0\n                elif int_answer < self.gold_answer:\n                    text, reward = \"Too low\", 0.0\n                else:\n                    text, reward = \"Too high\", 0.0\n                return Message(role=\"user\", content=text), reward\n            else:\n                return RETURN_ON_FAIL\n        except ValueError:\n            return RETURN_ON_FAIL\n\n    async def step(self, action: Action) -> StepResult:\n        # step 1: parse the action tokens into a message\n        # this step is specific to our library, but usually templated, so you can just copy it.\n        (action_message, _parse_success) = self.renderer.parse_response(action)\n\n        # step 2: based on the string answer, we compute the reward and the user turn.\n        # This part is NOT templated, so you need to implement it. 
But it is plain python without using special libraries.\n        action_content = ensure_text(action_message[\"content\"])\n        user_turn, reward = self._get_user_turn(action_content)\n\n        # step 3: update the conversation history\n        self.turns.append({\"role\": \"player\", \"content\": action_content})\n        self.turns.append(user_turn)\n        episode_done = (reward == 1) or (len(self.turns) // 2 >= MAX_TURNS)\n\n        # step 4: return the step result\n        step_result = StepResult(\n            next_observation=self._obs,\n            next_stop_condition=self.stop_condition,\n            episode_done=episode_done,\n            reward=reward,\n            logs={\n                \"guess\": action_content,\n                \"feedback\": ensure_text(user_turn[\"content\"]),\n                \"target\": self.gold_answer,\n            },\n        )\n\n        return step_result\n\n\n@dataclass(frozen=True)\nclass GuessNumberEnvGroupBuilder(EnvGroupBuilder):\n    answer: int\n    renderer: Renderer\n    num_envs: int\n\n    async def make_envs(self) -> Sequence[Env]:\n        return [GuessNumberEnv(self.answer, self.renderer) for _ in range(self.num_envs)]\n\n\n# The dataset just indexes into the list of possible answers.\n\n\n@dataclass(frozen=True)\nclass GuessNumberDataset(RLDataset):\n    answers: Sequence[int]\n    renderer: Renderer\n    batch_size: int\n    group_size: int\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        return [\n            GuessNumberEnvGroupBuilder(\n                answer=self.answers[index * self.batch_size + i],\n                renderer=self.renderer,\n                num_envs=self.group_size,\n            )\n            for i in range(self.batch_size)\n        ]\n\n    def __len__(self) -> int:\n        return len(self.answers) // self.batch_size\n\n\n@chz.chz\nclass GuessNumberDatasetBuilder(RLDatasetBuilder):\n    batch_size: int\n    renderer_name: str\n    train_group_size: int\n    base_url: str | None = None\n    model_name: str\n    train_fraction: float = 0.9\n    test_group_size: int = 4\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset]:\n        player_renderer = get_renderer(self.renderer_name, get_tokenizer(self.model_name))\n        train_numbers, test_numbers = self._get_train_and_test_numbers()\n        assert self.batch_size <= len(train_numbers)\n        training_dataset = GuessNumberDataset(\n            answers=train_numbers,\n            renderer=player_renderer,\n            batch_size=self.batch_size,\n            group_size=self.train_group_size,\n        )\n        test_dataset = GuessNumberDataset(\n            answers=test_numbers,\n            renderer=player_renderer,\n            batch_size=len(test_numbers),  # test set only contains one batch\n            group_size=self.test_group_size,\n        )\n        return training_dataset, test_dataset\n\n    def _get_train_and_test_numbers(self) -> tuple[list[int], list[int]]:\n        rng = random.Random(0)\n        num_train_datapoints = int(_UPPER_BOUND * self.train_fraction)\n        shuffled_numbers = rng.sample(range(0, _UPPER_BOUND), _UPPER_BOUND)\n        train_numbers = shuffled_numbers[:num_train_datapoints]\n        test_numbers = shuffled_numbers[num_train_datapoints:]\n        return train_numbers, test_numbers\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/guess_number/train.py",
    "content": "import asyncio\nfrom datetime import datetime\n\nimport chz\n\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.recipes.multiplayer_rl.guess_number.env import GuessNumberDatasetBuilder\nfrom tinker_cookbook.rl import train\n\n\n@chz.chz\nclass CLIConfig:\n    model_name: str = \"Qwen/Qwen3-4B-Instruct-2507\"\n    renderer_name: str | None = None\n    group_size: int = 8\n    batch_size: int = 32\n    learning_rate: float = 3e-5\n    max_tokens: int = 64\n    eval_every: int = 5\n    save_every: int = 20\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    log_path: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\ndef build_config(cli_config: CLIConfig) -> train.Config:\n    model_name = cli_config.model_name\n    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(\n        cli_config.model_name\n    )\n\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    run_name = f\"{model_name}-{cli_config.group_size}group-{cli_config.batch_size}batch-{cli_config.learning_rate}lr-{date_and_time}\"\n\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/guess-number/{run_name}\"\n\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    dataset_builder = GuessNumberDatasetBuilder(\n        batch_size=cli_config.batch_size,\n        model_name=model_name,\n        renderer_name=renderer_name,\n        train_group_size=cli_config.group_size,\n    )\n\n    return train.Config(\n        model_name=model_name,\n        renderer_name=renderer_name,\n        log_path=log_path,\n        dataset_builder=dataset_builder,\n        learning_rate=cli_config.learning_rate,\n        max_tokens=cli_config.max_tokens,\n        eval_every=cli_config.eval_every,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        max_steps=cli_config.max_steps,\n    )\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    config = build_config(cli_config)\n    # Avoid clobbering log dir from your previous run:\n    cli_utils.check_log_dir(\n        config.log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists\n    )\n    asyncio.run(train.main(config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/text_arena/README.md",
    "content": "# Learning Tic-Tac-Toe via Self-Play\n\n## Installation\n\n```bash\nuv pip install 'tinker-cookbook[multiplayer-rl] @ git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly'\n```\n\nMany research studies involve training several different language model agents jointly. We cover one simple example, where the language model learns to play tic-tac-toe with itself.\nWe show how to coordinate the steps of two *Environment* objects such that both the winning and the losing trajectory will be used to fine-tune the weights.\n\n```bash\npython -m tinker_cookbook.recipes.multiplayer_rl.text_arena.train\n```\n\nThe `test/env/all/reward/total` should increase from ~ -1.0 to >=0 in 40 steps.\n\n### Background\n\nThe TextArena [1] already implements an environment object where two players can specify which position to play using ``[0], [1], [2] ...`` in tic-tac-toe and compute how the board changes, the observation (prompt) for each language model player, and the final reward.\n\nHere's an example language model input:\n```\n[GAME] You are Player 0 in Tic Tac Toe.\nYour goal is to win three in a row (horizontally, vertically, or diagonally) on the board.\nOn your turn, you should select the square number (0-8) you want to put your mark in next.\nFor example, '[4]' places your mark in the center cell of the board.\n\nAs Player 0, you will be 'O', while your opponent is 'X'.\n\n[GAME] Current Board:\n\n 0 | 1 | 2\n---+---+---\n 3 | 4 | 5\n---+---+---\n 6 | 7 | 8\n\nAvailable Moves: '[0]', '[1]', '[2]', '[3]', '[4]', '[5]', '[6]', '[7]', '[8]'\n```\n\nIf the language model wants to play in the middle, it can output `[4]`.\n\n### Coordinators\n\nTraining an LLM to play against a fixed LLM is straightforward -- this is quite similar to our twenty-question example, where we sample the response from another LLM in the `Env.step` function.\nHowever, in this example, we want to train on both trajectories where the language model plays on each side.\nTherefore, in the `Env.step` function, we need to receive the opponent's action, which is generated in another trajectory in another `Environment` object.\nThis motivates the design of the `Coordinator` class, which passes the LLM-generated text between two `Environment` objects and synchronizes the two `Environment` objects to alternate taking steps.\n\nIn our implementation, the `TwoPlayerCoordinator` object is shared across two `Environment` objects, and it:\n- wraps the tic-tac-toe environment from the TextArena [1],\n- waits for a specific player's turn to begin, and\n- allows one player to `make_move` on the board, and notifies the other player that the move is complete.\n\nAs a result, in the `Environment.step` function, we can:\n- determine when to start the next move, since `TwoPlayerCoordinator` informs us when the opponent has finished.\n- compute the next observation, since `TwoPlayerCoordinator` passes the move from the opponent.\n\n### Extension\n\nMulti-agent training is a very active research direction with many different algorithm choices, e.g., debate [2], prover-verifier games [3], etc.\nWe hope Tinker can support the broader research community to explore these opportunities!\n\n\n\n### References\n\n[1] Guertler, L., Cheng, B., Yu, S., Liu, B., Choshen, L., & Tan, C. (2025). *TextArena*. arXiv preprint arXiv:2504.11442. https://arxiv.org/abs/2504.11442\n[2] Khan, A., Hughes, J., Valentine, D., Ruis, L., Sachan, K., Radhakrishnan, A., Grefenstette, E., Bowman, S. R., Rocktäschel, T., & Perez, E. (2024). 
Debating with More Persuasive LLMs Leads to More Truthful Answers. Proceedings of Machine Learning Research, 235, 23662-23733.\n[3] Kirchner, J. H., Chen, Y., Edwards, H., Leike, J., McAleese, N., & Burda, Y. (2024). Prover-verifier games improve legibility of LLM outputs. arXiv preprint arXiv:2407.13692. https://arxiv.org/abs/2407.13692\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/text_arena/env.py",
    "content": "\"\"\"TextArena TicTacToe environment for tinker RL.\"\"\"\n\nimport asyncio\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\nfrom typing import ClassVar\n\nimport chz\nimport tinker\nfrom tinker import types\n\ntry:\n    import textarena as ta\nexcept ImportError:\n    raise ImportError(\n        \"textarena is required for this recipe. \"\n        \"Install it with: uv pip install 'tinker-cookbook[multiplayer-rl] @ \"\n        \"git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly'\"\n    ) from None\n\nfrom tinker_cookbook.completers import StopCondition, TinkerMessageCompleter\nfrom tinker_cookbook.renderers import Message, Renderer, get_renderer, get_text_content\nfrom tinker_cookbook.rl.types import (\n    Action,\n    Env,\n    EnvGroupBuilder,\n    Observation,\n    RLDataset,\n    RLDatasetBuilder,\n    StepResult,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nSTOP_CONDITION = [\"]\\n\"]\nILLEGAL_MOVE_REWARD = -2.0\n\n\nclass TwoPlayerCoordinator:\n    \"\"\"Coordinates a single two player game between two players. See README.md in this folder for more details.\"\"\"\n\n    def __init__(self, shared_env: ta.Env):\n        self.shared_env = shared_env  # Should already be reset\n        self.condition = asyncio.Condition()\n        self.illegal_player_id: int | None = None\n\n    @property\n    def state(self) -> ta.State:\n        return self.shared_env.state  # type: ignore\n\n    @property\n    def current_player_id(self) -> int:\n        \"\"\"Get the current player ID from the environment state.\"\"\"\n        return self.state.current_player_id\n\n    @property\n    def game_done(self) -> bool:\n        \"\"\"Check if the game is done. Either the game state is done, or some player made an illegal move\"\"\"\n        return self.state.done or self.illegal_player_id is not None\n\n    @property\n    def rewards(self) -> dict | None:\n        \"\"\"Get rewards from the environment state.\"\"\"\n        return self.state.rewards\n\n    async def wait_across_env(self, player_id: int) -> None:\n        \"\"\"Wait until the opponent has finished their turn\"\"\"\n        async with self.condition:\n            await self.condition.wait_for(\n                lambda: self.current_player_id == player_id or self.game_done\n            )\n\n    async def make_move(self, player_id: int, move: str) -> None:\n        \"\"\"Make a move and notify waiting players.\"\"\"\n        async with self.condition:\n            # Ensure it's actually this player's turn\n            before_player_id = self.current_player_id\n\n            if not self.game_done and (self.current_player_id != player_id):\n                raise ValueError(\n                    f\"Not player {player_id}'s turn (current: {self.current_player_id})\"\n                )\n\n            done, _ = self.shared_env.step(move)\n\n            if done:\n                self.shared_env.close()\n            else:\n                # we will know that the move is illegal if the next player's id has not changed after the move\n                if self.current_player_id == before_player_id:\n                    self.illegal_player_id = before_player_id\n\n            # Notify all waiting players about the state change\n            self.condition.notify_all()\n\n\n@dataclass\nclass TwoPlayerEnv(Env):\n    \"\"\"Two player TextArena environment.\"\"\"\n\n    player_id: int  # 0 or 1\n    coordinator: TwoPlayerCoordinator\n    self_play: bool\n    renderer: 
Renderer\n    opponent_policy: TinkerMessageCompleter | None\n\n    def __post_init__(self):\n        assert self.self_play == (self.opponent_policy is None), (\n            \"If self_play is True, opponent_policy must be None\"\n        )\n\n    @property\n    def stop_condition(self) -> StopCondition:\n        return STOP_CONDITION  # TextArena envs look for action in square brackets\n\n    async def wait_for_turn(self) -> None:\n        \"\"\"If the game is not done, wait until the opponent's to finish playing their turn\"\"\"\n        if not self.coordinator.game_done:\n            if self.self_play:\n                await self.coordinator.wait_across_env(self.player_id)\n            else:\n                await self.opponent_step()\n\n    async def initial_observation(self) -> tuple[Observation, StopCondition]:\n        if self.player_id != 0:\n            await self.wait_for_turn()\n        return self.get_observation(), self.stop_condition\n\n    async def opponent_step(self) -> None:\n        \"\"\"When not self_play, the opponent policy takes a step on the shared environment\"\"\"\n        assert self.opponent_policy is not None\n        opponent_player_id, opponent_observation_str = self.coordinator.shared_env.get_observation()\n        assert isinstance(opponent_player_id, int) and isinstance(opponent_observation_str, str)\n        assert opponent_player_id == 1 - self.player_id, (\n            f\"Opponent player ID should be 1 - [the id of the policy player], {opponent_player_id=}, {self.player_id=}\"\n        )\n        opponent_convo: list[Message] = [{\"role\": \"user\", \"content\": opponent_observation_str}]\n        opponent_response = await self.opponent_policy(opponent_convo)\n        opponent_action_content = get_text_content(opponent_response)\n        await self.coordinator.make_move(opponent_player_id, opponent_action_content)\n\n    async def step(self, action: Action) -> StepResult:\n        \"\"\"Take a step in the environment.\"\"\"\n        if self.coordinator.game_done:\n            return self.get_done_step()\n        assert self.coordinator.current_player_id == self.player_id, \"Not the current player's turn\"\n\n        # make a move ...\n        action_message: Message = self.renderer.parse_response(action)[0]\n        action_content = get_text_content(action_message)\n        await self.coordinator.make_move(self.player_id, action_content)\n\n        # we wait here rather than the beginning of this function,\n        # because we want to know whether this player still needs to make a future move, and give that information to StepResult.\n        # it is possible that this player's move is legal and hence make the game continue, kicking off another step in this step; however, the opponent ends the game.\n        await self.wait_for_turn()\n        return StepResult(\n            reward=self.compute_reward(),\n            episode_done=self.coordinator.game_done,\n            next_observation=self.get_observation(),\n            next_stop_condition=self.stop_condition,\n            metrics={},\n        )\n\n    def get_done_step(self) -> StepResult:\n        return StepResult(\n            reward=0.0,\n            episode_done=True,\n            next_observation=types.ModelInput.empty(),\n            next_stop_condition=STOP_CONDITION,\n            metrics={},\n        )\n\n    def compute_reward(self) -> float:\n        if self.coordinator.illegal_player_id == self.player_id:\n            return ILLEGAL_MOVE_REWARD\n        if self.coordinator.rewards:\n    
        return self.coordinator.rewards[self.player_id]\n        return 0.0\n\n    def get_observation(self) -> types.ModelInput:\n        current_player_id, observation_str = self.coordinator.shared_env.get_observation()\n        if not self.coordinator.game_done:\n            assert isinstance(current_player_id, int) and isinstance(observation_str, str)\n            assert current_player_id == self.player_id, (\n                f\"Observation should be for the current player, obs: {observation_str}, current_player_id: {current_player_id}, player_id: {self.player_id}\"\n            )\n        return self.renderer.build_generation_prompt([{\"role\": \"user\", \"content\": observation_str}])\n\n\n@dataclass\nclass TwoPlayerEnvGroupBuilder(EnvGroupBuilder):\n    \"\"\"Builder for groups of two player TextArena environments sharing the same game.\"\"\"\n\n    game_name: str\n    renderer: Renderer\n    num_envs: int\n    self_play: bool\n    num_players: ClassVar[int] = 2\n    opponent_policy: TinkerMessageCompleter | None = None\n\n    async def make_envs(self) -> Sequence[Env]:\n        \"\"\"Create a group of environments sharing the same TextArena game.\"\"\"\n        if self.num_envs % 2 != 0:\n            raise ValueError(\"this env requires an even number of environments (players)\")\n\n        def _construct_coordinator() -> TwoPlayerCoordinator:\n            \"\"\"During training, the coordinator performs necessary blocking/synchronization, so that the policys can take turns to make moves on the shared environment, across different Environment objects\"\"\"\n            shared_env = ta.make(env_id=self.game_name)\n            shared_env.reset(num_players=self.num_players)\n            return TwoPlayerCoordinator(shared_env=shared_env)\n\n        envs = []\n        for _ in range(self.num_envs // 2):\n            if self.self_play:\n                coordinator = _construct_coordinator()\n                # if self_play, then we need to share the same coordinator across all environments\n                coordinators = [coordinator for _ in range(self.num_players)]\n            else:\n                # if not self_play, we can just create a different coordinator for each environment\n                coordinators = [_construct_coordinator() for _ in range(self.num_players)]\n\n            envs += [\n                TwoPlayerEnv(\n                    player_id=i,\n                    coordinator=coordinators[i],\n                    renderer=self.renderer,\n                    self_play=self.self_play,\n                    opponent_policy=self.opponent_policy,\n                )\n                for i in range(2)\n            ]\n        return envs\n\n\nclass TwoPlayerTextArenaDataset(RLDataset):\n    \"\"\"Dataset for TextArena environments.\"\"\"\n\n    def __init__(self, batch_size: int, builder: TwoPlayerEnvGroupBuilder, num_datapoints: int):\n        self.batch_size = batch_size\n        self.builder = builder\n        self.num_datapoints = num_datapoints\n        assert self.num_datapoints % self.builder.num_players == 0, (\n            \"num_datapoints must be divisible by num_players\"\n        )\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        return [\n            self.builder\n            for i in range(self.batch_size // self.builder.num_players)\n            if (index * self.batch_size + self.builder.num_players * i) < self.num_datapoints\n        ]\n\n    def __len__(self) -> int:\n        return self.num_datapoints // 
self.batch_size\n\n\n@chz.chz\nclass TwoPlayerTextArenaDatasetBuilder(RLDatasetBuilder):\n    batch_size: int\n    num_train_datapoints: int\n    num_test_datapoints: int\n    base_url: str | None = None\n    model_name: str\n    game_name: str\n    renderer_name: str\n\n    def _construct_opponent_policy(self, renderer: Renderer) -> TinkerMessageCompleter:\n        \"\"\"Play against a fixed policy during testing.\"\"\"\n        service_client = tinker.ServiceClient(base_url=self.base_url)\n        sampling_client = service_client.create_sampling_client(base_model=self.model_name)\n        return TinkerMessageCompleter(\n            sampling_client=sampling_client,\n            renderer=renderer,\n            max_tokens=64,\n            stop_condition=STOP_CONDITION,\n        )\n\n    async def __call__(self) -> tuple[TwoPlayerTextArenaDataset, TwoPlayerTextArenaDataset | None]:\n        \"\"\"Build the dataset for training and testing.\"\"\"\n        renderer = get_renderer(self.renderer_name, get_tokenizer(self.model_name))\n\n        # The training dataset performs self-play\n        train_builder = TwoPlayerEnvGroupBuilder(\n            game_name=self.game_name,\n            renderer=renderer,\n            num_envs=2,\n            self_play=True,\n        )\n        train_dataset = TwoPlayerTextArenaDataset(\n            batch_size=self.batch_size,\n            builder=train_builder,\n            num_datapoints=self.num_train_datapoints,\n        )\n\n        # The testing dataset plays against a fixed policy, constructed by self._opponent_policy\n        test_builder = TwoPlayerEnvGroupBuilder(\n            game_name=self.game_name,\n            renderer=renderer,\n            num_envs=2,\n            self_play=False,\n            opponent_policy=self._construct_opponent_policy(renderer),\n        )\n        test_dataset = TwoPlayerTextArenaDataset(\n            batch_size=self.num_test_datapoints,\n            builder=test_builder,\n            num_datapoints=self.num_test_datapoints,\n        )\n        return train_dataset, test_dataset\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/text_arena/train.py",
    "content": "import asyncio\nfrom datetime import datetime\n\nimport chz\n\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.recipes.multiplayer_rl.text_arena.env import TwoPlayerTextArenaDatasetBuilder\nfrom tinker_cookbook.rl import train\n\n\n@chz.chz\nclass CLIConfig:\n    model_name: str = \"Qwen/Qwen3-4B-Instruct-2507\"\n    renderer_name: str | None = None\n    game_name: str = \"TicTacToe-v0\"\n    batch_size: int = 512\n    num_train_datapoints: int = 131072\n    num_test_datapoints: int = 128\n    learning_rate: float = 3e-5\n    max_tokens: int = 64\n    eval_every: int = 5\n    save_every: int = 20\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    log_path: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\ndef build_config(cli_config: CLIConfig) -> train.Config:\n    model_name = cli_config.model_name\n    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(\n        cli_config.model_name\n    )\n\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    run_name = f\"{model_name}-{cli_config.game_name.lower().replace('-v0', '')}-{cli_config.batch_size}batch-{cli_config.learning_rate}lr-{date_and_time}\"\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/text-arena/{run_name}\"\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    dataset_builder = TwoPlayerTextArenaDatasetBuilder(\n        batch_size=cli_config.batch_size,\n        model_name=model_name,\n        game_name=cli_config.game_name,\n        num_train_datapoints=cli_config.num_train_datapoints,\n        num_test_datapoints=cli_config.num_test_datapoints,\n        renderer_name=renderer_name,\n    )\n\n    return train.Config(\n        model_name=model_name,\n        renderer_name=renderer_name,\n        log_path=log_path,\n        dataset_builder=dataset_builder,\n        learning_rate=cli_config.learning_rate,\n        max_tokens=cli_config.max_tokens,\n        eval_every=cli_config.eval_every,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        max_steps=cli_config.max_steps,\n    )\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    config = build_config(cli_config)\n    # Avoid clobbering log dir from your previous run:\n    cli_utils.check_log_dir(\n        config.log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists\n    )\n    asyncio.run(train.main(config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/twenty_questions/README.md",
    "content": "# Playing Twenty Questions Against Another Language Model\n\n```bash\npython -m tinker_cookbook.recipes.multiplayer_rl.twenty_questions.train\n```\n\nThe `test/env/all/reward/total` should increase from ~10% to ~20% after 20 steps.\n\n### Background: Twenty Questions\n\nThis game involves a *player* and an *answerer*. The answerer has a *secret word* in mind, and the player needs to ask yes/no questions to guess it. Here is an example where the secret word is apple:\n\n> **Player**: Is it an animal?\n> **Answerer**: No\n>\n> **Player**: Is it a plant?\n> **Answerer**: Yes\n>\n> **Player**: Does it grow on a tree?\n> **Answerer**: Yes\n>\n> **Player**: Guess: Apple\n> **Answerer**: Yes\n\nTo use Tinker to train LLMs to play twenty questions, we mainly need to implement the `TwentyQuestionsEnv` class.\n\n### Implementing a Training Environment\n\nEach environment object has exactly one secret word; it determines what the conversation looks like based on the player’s (policy’s) questions (actions).\n\nThe most important logic happens within the class method `TwentyQuestionsEnv.step` in [env.py](./env.py). This `step` function takes in:\n\n* An action from the language model, which is a sequence of int tokens that can be parsed into a natural language question, e.g., “`Is it a plant?`“\n\nThis function outputs a `StepResult`, which contains:\n\n* Reward: 1 if the player guesses the keyword “Apple“, 0 otherwise. (see `TwentyQuestionsEnv._compute_reward(content)`)\n* Whether this episode should end: either when the player correctly guesses the secret word Apple, or the player has asked more than 20 questions.\n* next_stop_condition: it’s usually the stop token of the policy language model.\n* next_observation: what the player (policy) sees in the next turn, in token space. In twenty questions, it involves the conversation history, the last question it asked, and the answer to the last question.\n\nTo automatically compute the next_observation — what the answerer would respond — we need to sample from another language model to play the role of answerer. We will go through it next.\n\n### Sampling from Another Language Model in `Environment.step`\n\nHere's what the answerer receives as input and produces as output:\n\n> **Answerer Input:** “Answer yes/no questions about your secret word. Your secret word is apple. Question: Is it a plant?“\n>\n> **Answerer Output:** “Yes“.\n\nThe Environment class will be responsible for preparing the Input (see `_convo_for_answerer`), which will then be handed to the answerer for sampling. You can sample from:\n\n* a third-party LLM library that you are already familiar with, or\n* the `MessageCompleter` object that we have implemented, which is an API interface similar to OpenAI Chat Completion. You can sample from any model we support, or a model you have just trained (or are still training) on Tinker!\n\n### Extensions\n\nOur demo is simple, involving only one LLM other than the policy and a very simple data format (a single string as an answer). One can easily add more LLMs into this environment, e.g., using a language model as a judge and sampling from `judge_message_completer`. One can also easily extend it to more complicated datapoint types, including rubrics  [1] for LLM-judges [2], conversation starters, reference answers, or private information for human simulators [3, 4]. 
There are many other possibilities, and we look forward to seeing what you come up with using Tinker!\n\n### Next\n\nIn this example, we play against a static language model answerer, which does not update during training.\nIn `recipes.multiplayer_rl.text_arena`, we demonstrate an example (tic-tac-toe) that updates the weights of both players in a game.\n\n### References\n\n[1] Viswanathan, V., Sun, Y., Ma, S., Kong, X., Cao, M., Neubig, G., & Wu, T. (2025). Checklists are better than reward models for aligning language models. arXiv. https://arxiv.org/abs/2507.18624\n[2] Zheng, L., Chiang, W.-L., Sheng, Y., Zhuang, S., Wu, Z., Zhuang, Y., Lin, Z., Li, Z., Li, D., Xing, E. P., Zhang, H., Gonzalez, J. E., & Stoica, I. (2023). Judging LLM-as-a-judge with MT-bench and chatbot arena. arXiv. https://arxiv.org/abs/2306.05685\n[3] Zhou, Y., Jiang, S., Tian, Y., Weston, J., Levine, S., Sukhbaatar, S., & Li, X. (2025). SWEET-RL: Training multi-turn LLM agents on collaborative reasoning tasks. arXiv. https://arxiv.org/abs/2503.15478\n[4] Wu, S., Galley, M., Peng, B., Cheng, H., Li, G., Dou, Y., Cai, W., Zou, J., Leskovec, J., & Gao, J. (2025). CollabLLM: From passive responders to active collaborators. arXiv. https://arxiv.org/abs/2502.00640\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/twenty_questions/common_english_nouns.txt",
    "content": "actor\nadult\nairport\nant\napple\naunt\nbaby\nbag\nball\nballoon\nbanana\nbasket\nbear\nbed\nbee\nbeef\nbeer\nbell\nbicycle\nbiscuit\nbird\nblanket\nbook\nbottle\nbox\nboy\nbread\nbrother\nbuilding\nbutter\ncake\ncamera\ncandle\ncar\ncarpenter\ncarrot\ncastle\ncat\ncave\ncheese\nchicken\nchild\ncity\nclock\ncloud\ncoat\ncoffee\ncomputer\ncookie\ncorn\ncousin\ncow\ncup\ndad\ndaughter\ndog\ndoor\nduck\neagle\negg\nfather\nfire\nfish\nflag\nfloor\nforest\nfork\nfriend\ngarden\ngarlic\nglass\ngrandmother\ngrandfather\nguitar\nguest\nhammer\nhat\nhero\nhill\nhorse\nhospital\nhotel\nhouse\nhusband\nice\nice cream\nisland\njuice\njudge\nkangaroo\nkey\nkid\nking\nkitten\nknife\nladder\nlake\nlamp\nlady\nlion\nlock\nman\nmilk\nmirror\nmother\nmountain\nmouse\nnurse\norange\nparent\npark\npassenger\npen\npencil\nperson\nphone\npicture\npillow\npizza\npotato\nqueen\nradio\nrain\nrice\nring\nriver\nroad\nrunner\nsalad\nsalt\nsandwich\nscissors\nschool\nsea\nsheep\nshoe\nshop\nsinger\nsister\nsnow\nsoldier\nsoup\nstar\nstone\nstorm\nstreet\nstudent\nsugar\nsun\nsupermarket\ntable\ntea\nteacher\ntomato\ntown\ntoy\ntrain\nvillage\nwatch\nwater\nwife\nwind\nwindow\nwine\nwolf\nwoman\nwriter\nyogurt\nzebra\nzoo\nzucchini\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/twenty_questions/env.py",
    "content": "import functools\nimport random\nimport re\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport chz\nimport tinker\nfrom tinker import ModelInput\n\nfrom tinker_cookbook import model_info\nfrom tinker_cookbook.completers import (\n    MessageCompleter,\n    StopCondition,\n    TinkerMessageCompleter,\n)\nfrom tinker_cookbook.model_info import get_recommended_renderer_name\nfrom tinker_cookbook.renderers import Message, Renderer, get_renderer, get_text_content\nfrom tinker_cookbook.rl.types import (\n    Action,\n    Env,\n    EnvGroupBuilder,\n    RLDataset,\n    RLDatasetBuilder,\n    StepResult,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.utils import logtree\n\nANSWERER_SYSTEM_PROMPT = \"\"\"\nYou are the answerer in a game of 20 questions. You should only ever respond with 'yes' or 'no'. Your secret word is {answer}. If the other player guesses it with Guess: <answer>, respond with 'yes' only if the answer is precisely your secret word.\n\"\"\".strip()\n\nPLAYER_SYSTEM_PROMPT = \"\"\"\nYou are the player in a game of 20 questions. You will ask a series of yes/no questions to the answerer. You will win if you can guess the answer in 20 questions or less. You will lose if you ask more than 20 questions. To guess the answer, write a line of the form 'Guess: <answer>' (without the angle brackets). The answer is always a single word -- don't use articles like 'a' or 'the'. Your questions should be one line, and less than 20 words.\n\"\"\".strip()\n\n\nclass TwentyQuestionsEnv(Env):\n    def __init__(self, answerer: MessageCompleter, answer: str, renderer: Renderer):\n        self.answerer: MessageCompleter = answerer\n        self.answer: str = answer\n        self.sys_for_answerer: Message = {\n            \"role\": \"system\",\n            \"content\": ANSWERER_SYSTEM_PROMPT.format(answer=answer),\n        }\n        self.sys_for_player: Message = {\n            \"role\": \"system\",\n            \"content\": PLAYER_SYSTEM_PROMPT,\n        }\n        self.renderer: Renderer = renderer\n        self.turns: list[Message] = []\n\n    @property\n    def stop_condition(self) -> StopCondition:\n        return self.renderer.get_stop_sequences()\n\n    def _convo_for_player(self) -> list[Message]:\n        \"\"\"Conversation from the player's perspective.\"\"\"\n        game_role_to_chat_role = {\"answerer\": \"user\", \"player\": \"assistant\"}\n        return [self.sys_for_player] + [\n            {\"role\": game_role_to_chat_role[turn[\"role\"]], \"content\": turn[\"content\"]}\n            for turn in self.turns\n        ]\n\n    def _get_obs(self) -> ModelInput:\n        \"\"\"Get the observation for the player in tokenized form\"\"\"\n        return self.renderer.build_generation_prompt(self._convo_for_player())\n\n    def _convo_for_answerer(self) -> list[Message]:\n        \"\"\"Conversation from the answerer's perspective.\"\"\"\n        game_role_to_chat_role = {\"answerer\": \"assistant\", \"player\": \"user\"}\n        return (\n            [self.sys_for_answerer]\n            + [\n                {\"role\": game_role_to_chat_role[turn[\"role\"]], \"content\": turn[\"content\"]}\n                for turn in self.turns[\n                    -1:\n                ]  # show the answerer only the last turn because the response to each player turn should be independent of the previous turns\n            ]\n        )\n\n    async def initial_observation(self) -> tuple[ModelInput, 
StopCondition]:\n        return self._get_obs(), self.stop_condition\n\n    def _compute_reward(self, content: str) -> float:\n        \"\"\"\n        Returns 1.0 if the content contains the answer, 0.0 otherwise.\n        \"\"\"\n        match = re.match(r\"Guess: (.*)\", content)\n        maybe_answer = match.group(1) if match else None\n        content_contains_answer = (maybe_answer is not None) and (\n            maybe_answer.lower() == self.answer.lower()\n        )\n        return 1.0 if content_contains_answer else 0.0\n\n    async def step(self, action: Action) -> StepResult:\n        \"\"\"\n        In one step,\n        1. The environment accepts an action from the player (a message containin a question or a guess).\n        2. We obtain the response from the answerer and update the conversation history in self.turns.\n        3. We calculate the reward and decide whether to end the episode.\n        4. We return these information, along with the next observation built from the updated conversation history.\n        \"\"\"\n\n        # step 1: accepts the action from the player (policy)\n        (action_message, _parse_success) = self.renderer.parse_response(action)\n        self.turns.append({\"role\": \"player\", \"content\": action_message[\"content\"]})\n\n        # step 2: the answerer responds\n        answer_message = await self.answerer(self._convo_for_answerer())\n        self.turns.append({\"role\": \"answerer\", \"content\": answer_message[\"content\"]})\n\n        # step 3: we calculate the reward and decide whether to end the episode.\n        # the episode ends if the player guessed the answer or the player asked more than 20 questions\n        action_content = get_text_content(action_message)\n        reward = self._compute_reward(action_content)\n        episode_done = (reward == 1) or (len(self.turns) // 2 >= 20)\n\n        # Log the turn\n        turn_num = len(self.turns) // 2\n        logtree.log_text(f\"Turn {turn_num} - Player: {action_message['content']}\")\n        logtree.log_text(f\"Turn {turn_num} - Answerer: {answer_message['content']}\")\n        if episode_done:\n            logtree.log_text(\n                f\"Game Over - Secret: {self.answer}, Won: {'✓' if reward == 1 else '✗'}, Turns: {turn_num}\"\n            )\n\n        # step 4: we return the next observation, reward, and whether the episode is done\n        step_result = StepResult(\n            next_observation=self._get_obs(),\n            next_stop_condition=self.stop_condition,\n            episode_done=episode_done,\n            reward=reward,\n        )\n\n        return step_result\n\n\n# The EnvGroupBuilder is trivial: just return a list of copies of the same environment.\n\n\n@functools.cache\ndef _load_words_from_file() -> list[str]:\n    module_dir = Path(__file__).parent\n    file_path = module_dir / \"common_english_nouns.txt\"\n\n    rng = random.Random(0)\n    with open(file_path) as f:\n        words = [line.strip() for line in f.readlines()]\n    rng.shuffle(words)\n    return words\n\n\n@dataclass(frozen=True)\nclass TwentyQuestionsEnvGroupBuilder(EnvGroupBuilder):\n    answerer: MessageCompleter\n    answer: str\n    renderer: Renderer\n    num_envs: int\n\n    async def make_envs(self) -> Sequence[Env]:\n        return [\n            TwentyQuestionsEnv(self.answerer, self.answer, self.renderer)\n            for _ in range(self.num_envs)\n        ]\n\n\n# The dataset just indexes into the list of possible answers.\n\n\n@dataclass(frozen=True)\nclass 
TwentyQuestionsDataset(RLDataset):\n    answerer: MessageCompleter\n    answers: Sequence[str]\n    renderer: Renderer\n    batch_size: int\n    group_size: int\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        return [\n            TwentyQuestionsEnvGroupBuilder(\n                answerer=self.answerer,\n                answer=self.answers[index * self.batch_size + i],\n                renderer=self.renderer,\n                num_envs=self.group_size,\n            )\n            for i in range(self.batch_size)\n        ]\n\n    def __len__(self) -> int:\n        return len(self.answers) // self.batch_size\n\n\n@chz.chz\nclass TwentyQuestionsDatasetBuilder(RLDatasetBuilder):\n    batch_size: int\n    model_name_for_tokenizer: str\n    renderer_name: str\n    group_size: int\n    base_url: str | None = None\n    num_epochs: int = 1\n    test_group_size: int = 32\n    answerer_base_model: str = \"meta-llama/Llama-3.1-8B-Instruct\"\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset]:\n        service_client = tinker.ServiceClient(base_url=self.base_url)\n        answerer = self._construct_answer_completer(service_client)\n        train_words, test_words = self._get_train_and_test_words()\n        player_renderer = get_renderer(\n            self.renderer_name, get_tokenizer(self.model_name_for_tokenizer)\n        )\n        assert self.batch_size <= len(train_words)\n        training_dataset = TwentyQuestionsDataset(\n            answerer=answerer,\n            answers=train_words,\n            renderer=player_renderer,\n            batch_size=self.batch_size,\n            group_size=self.group_size,\n        )\n        test_dataset = TwentyQuestionsDataset(\n            answerer=answerer,\n            answers=test_words,\n            renderer=player_renderer,\n            batch_size=len(test_words),  # test set only contains one batch\n            group_size=self.test_group_size,\n        )\n        return training_dataset, test_dataset\n\n    def _construct_answer_completer(self, service_client: tinker.ServiceClient) -> MessageCompleter:\n        if self.answerer_base_model.startswith(\"Qwen/Qwen3\"):\n            answerer_renderer_name = \"qwen3_disable_thinking\"\n        else:\n            answerer_renderer_name = model_info.get_recommended_renderer_name(\n                self.answerer_base_model\n            )\n        answerer_tokenizer = get_tokenizer(self.answerer_base_model)\n        answerer_renderer = get_renderer(answerer_renderer_name, answerer_tokenizer)\n        answerer_sampling_client = service_client.create_sampling_client(\n            base_model=self.answerer_base_model\n        )\n        answerer = TinkerMessageCompleter(\n            sampling_client=answerer_sampling_client, renderer=answerer_renderer, max_tokens=5\n        )\n        return answerer\n\n    def _get_train_and_test_words(self) -> tuple[list[str], list[str]]:\n        words = _load_words_from_file()\n        num_test = min(len(words) // 5, 100)\n        train_words = words[:-num_test]\n        test_words = words[-num_test:]\n        train_words = train_words * self.num_epochs\n        return train_words, test_words\n\n\ndef construct_minimal_20q_env(answer: str) -> TwentyQuestionsEnv:\n    answerer_model = \"meta-llama/Llama-3.1-8B-Instruct\"\n\n    service_client = tinker.ServiceClient()\n    answerer_sampling_client = service_client.create_sampling_client(base_model=answerer_model)\n    answerer = TinkerMessageCompleter(\n        
sampling_client=answerer_sampling_client,\n        renderer=get_renderer(\n            get_recommended_renderer_name(answerer_model), get_tokenizer(answerer_model)\n        ),\n        max_tokens=5,\n    )\n    policy_renderer = get_renderer(\n        get_recommended_renderer_name(answerer_model), get_tokenizer(answerer_model)\n    )  # this argument is not actually used and is a placeholder\n    env = TwentyQuestionsEnv(answerer, answer, policy_renderer)\n    return env\n"
  },
  {
    "path": "tinker_cookbook/recipes/multiplayer_rl/twenty_questions/train.py",
    "content": "import asyncio\nfrom datetime import datetime\n\nimport chz\n\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.recipes.multiplayer_rl.twenty_questions.env import (\n    TwentyQuestionsDatasetBuilder,\n)\nfrom tinker_cookbook.rl import train\n\n\n@chz.chz\nclass CLIConfig:\n    model_name: str = \"Qwen/Qwen3-4B-Instruct-2507\"\n    renderer_name: str | None = None\n    group_size: int = 8\n    num_epochs: int = 100\n    batch_size: int = 64\n    learning_rate: float = 3e-5\n    max_tokens: int = 20\n    eval_every: int = 5\n    save_every: int = 20\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    log_path: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\ndef build_config(cli_config: CLIConfig) -> train.Config:\n    model_name = cli_config.model_name\n    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(\n        cli_config.model_name\n    )\n\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    run_name = f\"{model_name}-{cli_config.group_size}group-{cli_config.batch_size}batch-{cli_config.learning_rate}lr-{date_and_time}\"\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/twenty-questions-rl/{run_name}\"\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    dataset_builder = TwentyQuestionsDatasetBuilder(\n        batch_size=cli_config.batch_size,\n        model_name_for_tokenizer=model_name,\n        renderer_name=renderer_name,\n        group_size=cli_config.group_size,\n        num_epochs=cli_config.num_epochs,\n    )\n\n    return train.Config(\n        model_name=model_name,\n        renderer_name=renderer_name,\n        log_path=log_path,\n        dataset_builder=dataset_builder,\n        learning_rate=cli_config.learning_rate,\n        max_tokens=cli_config.max_tokens,\n        eval_every=cli_config.eval_every,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        max_steps=cli_config.max_steps,\n    )\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    config = build_config(cli_config)\n    # Avoid clobbering log dir from your previous run:\n    cli_utils.check_log_dir(\n        config.log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists\n    )\n    asyncio.run(train.main(config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/preference/README.md",
    "content": "# Learning from Preferences\n\nMany applications involve learning from preferences beyond scalar rewards. We provide a few examples here:\n\n1. [Shorter](./shorter/): we introduce the `PairwisePreferenceRLDatasetBuilder` abstraction and walk through a simple example that trains a model to generate shorter responses.\n2. [RLHF](./rlhf/): we walk through the standard RLHF pipeline from [1, 2]. This pipeline involves three stages: supervised fine-tuning, reward model learning, and reinforcement learning.\n3. [DPO](./dpo/): we optimize for human preferences using the Direct Preference Optimization algorithm [3], which requires a custom loss function.\n\n**References:**\n1. Bai, Y., Kadavath, S., Kundu, S., Askell, A., Kernion, J., Jones, A., Chen, A., Goldie, A., Mirhoseini, A., McKinnon, C., Chen, C., Olsson, C., Olah, C., Hernandez, D., Drain, D., Ganguli, D., Li, D., Tran-Johnson, E., Perez, E., Kerr, J., Mueller, J., Ladish, J., Landau, J., Ndousse, K., Lukošiūtė, K., Lovitt, L., Sellitto, M., Elhage, N., Schiefer, N., Mercado, N., ... Kaplan, J. (2022). Training a helpful and harmless assistant with reinforcement learning from human feedback. arXiv. https://arxiv.org/abs/2204.05862\n2. Ouyang, L., Wu, J., Jiang, X., Almeida, D., Wainwright, C. L., Mishkin, P., Zhang, C., Agarwal, S., Slama, K., Ray, A., Schulman, J., Hilton, J., Kelton, F., Miller, L., Simens, M., Askell, A., Welinder, P., Christiano, P. F., Leike, J., & Lowe, R. (2022). Training language models to follow instructions with human feedback. Advances in Neural Information Processing Systems, 35, 27730-27744. https://arxiv.org/abs/2203.02155\n3. Rafailov, R., Sharma, A., Mitchell, E., Ermon, S., Manning, C. D., & Finn, C. (2023). Direct preference optimization: Your language model is secretly a reward model. arXiv. https://arxiv.org/abs/2305.18290\n"
  },
  {
    "path": "tinker_cookbook/recipes/preference/datasets.py",
    "content": "import logging\nimport re\nfrom typing import cast\n\nimport chz\nimport datasets\nimport pandas as pd\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.preference.preference_datasets import ComparisonDatasetBuilder\nfrom tinker_cookbook.preference.types import (\n    Comparison,\n    LabeledComparison,\n)\n\nlogger = logging.getLogger(__name__)\n\n\n# ============================================================================\n# Helper Functions\n# ============================================================================\n\n\ndef _hhh_parse_conversation(text: str) -> list[renderers.Message]:\n    \"\"\"Parse conversation text into message list format.\"\"\"\n    messages = []\n\n    # Split by Human: or Assistant: and capture the delimiter\n    parts = re.split(r\"(Human:|Assistant:)\", text)\n\n    # Skip the first part if it's empty (text starts with a delimiter)\n    if not parts[0].strip():\n        parts = parts[1:]\n\n    # Process parts in pairs: (delimiter, content)\n    for i in range(0, len(parts), 2):\n        if i + 1 < len(parts):\n            delimiter = parts[i].strip()\n            content = parts[i + 1].strip()\n\n            if delimiter == \"Human:\":\n                messages.append({\"role\": \"user\", \"content\": content})\n            elif delimiter == \"Assistant:\":\n                messages.append({\"role\": \"assistant\", \"content\": content})\n\n    return messages\n\n\ndef hhh_example_to_comparison(example: dict[str, str]) -> LabeledComparison | None:\n    \"\"\"Process a single preference pair into the new format.\"\"\"\n    chosen = _hhh_parse_conversation(example[\"chosen\"])\n    rejected = _hhh_parse_conversation(example[\"rejected\"])\n    if len(chosen) != len(rejected):\n        # Ran into at least one malformatted example like this\n        return None\n    match_bool_list = [\n        chosen_msg == rejected_msg\n        for chosen_msg, rejected_msg in zip(chosen, rejected, strict=True)\n    ]\n    if match_bool_list != [True] * (len(match_bool_list) - 1) + [False]:\n        # Ran into at least one malformatted example like this\n        return None\n    comparison = Comparison(\n        prompt_conversation=chosen[:-1],\n        completion_A=[chosen[-1]],\n        completion_B=[rejected[-1]],\n    )\n    return LabeledComparison(comparison=comparison, label=\"A\")\n\n\ndef _arena_parse_conversation(conversation: list) -> list[renderers.Message] | None:\n    \"\"\"Parse arena conversation to message format.\"\"\"\n    messages = []\n    for msg in conversation:\n        assert isinstance(msg, dict)\n        role = msg[\"role\"]\n        content_list = msg[\"content\"]\n\n        # Extract text content only\n        text_parts = []\n        for item in content_list:\n            if isinstance(item, dict) and item[\"type\"] == \"text\":\n                text = item.get(\"text\", \"\")\n                text_parts.append(text)\n            else:\n                logger.info(f\"Skipping arena conversation with non-text content: {msg}\")\n                return None\n\n        if text_parts:\n            content = \" \".join(text_parts)\n            if role == \"user\":\n                messages.append({\"role\": \"user\", \"content\": content})\n            elif role == \"assistant\":\n                messages.append({\"role\": \"assistant\", \"content\": content})\n\n    return messages\n\n\n# ============================================================================\n# Concrete Dataset Implementations\n# 
============================================================================\n\n\n@chz.chz\nclass Tulu38BComparisonBuilder(ComparisonDatasetBuilder):\n    \"\"\"Tulu 3.8B preference dataset comparison builder.\"\"\"\n\n    def get_train_and_test_datasets(self) -> tuple[datasets.Dataset, datasets.Dataset | None]:\n        dataset = datasets.load_dataset(\n            \"allenai/llama-3.1-tulu-3-8b-preference-mixture\", split=\"train\"\n        )\n        dataset = cast(datasets.Dataset, dataset)\n        dataset = dataset.shuffle(seed=0)\n        test_dataset = dataset.take(1024)\n        train_dataset = dataset.skip(1024)\n        return train_dataset, test_dataset\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        instruction = example[\"prompt\"]\n        chosen_response = example[\"chosen\"][1][\"content\"]\n        rejected_response = example[\"rejected\"][1][\"content\"]\n\n        prompt_conversation: list[renderers.Message] = [{\"role\": \"user\", \"content\": instruction}]\n\n        comparison = Comparison(\n            prompt_conversation=prompt_conversation,\n            completion_A=[{\"role\": \"assistant\", \"content\": chosen_response}],\n            completion_B=[{\"role\": \"assistant\", \"content\": rejected_response}],\n        )\n        return LabeledComparison(comparison=comparison, label=\"A\")\n\n\n@chz.chz\nclass HHHComparisonBuilder(ComparisonDatasetBuilder):\n    \"\"\"HHH dataset comparison builder.\"\"\"\n\n    test_size: int = 1024\n\n    def get_train_and_test_datasets(self) -> tuple[datasets.Dataset, datasets.Dataset | None]:\n        dataset = datasets.load_dataset(\"Anthropic/hh-rlhf\")\n        dataset = cast(datasets.DatasetDict, dataset)\n        train_dataset = dataset[\"train\"].shuffle(seed=0)\n        test_dataset = dataset[\"test\"].shuffle(seed=0).take(self.test_size)\n        return train_dataset, test_dataset\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        return hhh_example_to_comparison(example)\n\n\n@chz.chz\nclass HelpSteer3ComparisonBuilder(ComparisonDatasetBuilder):\n    \"\"\"HelpSteer3 dataset comparison builder.\"\"\"\n\n    def get_train_and_test_datasets(self) -> tuple[datasets.Dataset, datasets.Dataset | None]:\n        dataset = datasets.load_dataset(\"nvidia/HelpSteer3\", \"preference\")\n        dataset = cast(datasets.DatasetDict, dataset)\n        train_dataset = dataset[\"train\"].shuffle(seed=0)\n        test_dataset = dataset[\"validation\"].shuffle(seed=0).take(1024)\n        return train_dataset, test_dataset\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        context = example[\"context\"]\n        response1 = example[\"response1\"]\n        response2 = example[\"response2\"]\n        overall_preference = example[\"overall_preference\"]\n\n        # Skip ties\n        if overall_preference == 0:\n            return None\n\n        # Convert context to message format\n        comparison = Comparison(\n            prompt_conversation=context,\n            completion_A=[{\"role\": \"assistant\", \"content\": response1}],\n            completion_B=[{\"role\": \"assistant\", \"content\": response2}],\n        )\n        return LabeledComparison(\n            comparison=comparison, label=\"A\" if overall_preference < 0 else \"B\"\n        )\n\n\n@chz.chz\nclass UltraFeedbackComparisonBuilder(ComparisonDatasetBuilder):\n    \"\"\"UltraFeedback dataset comparison builder.\"\"\"\n\n    
def get_train_and_test_datasets(self) -> tuple[datasets.Dataset, datasets.Dataset | None]:\n        dataset = datasets.load_dataset(\n            \"argilla/ultrafeedback-binarized-preferences\", split=\"train\"\n        )\n        dataset = cast(datasets.Dataset, dataset)\n        dataset = dataset.shuffle(seed=0)\n        test_dataset = dataset.take(1024)\n        train_dataset = dataset.skip(1024)\n        return train_dataset, test_dataset\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        instruction = example[\"instruction\"]\n        chosen_response = example[\"chosen_response\"]\n        rejected_response = example[\"rejected_response\"]\n\n        prompt_conversation: list[renderers.Message] = [{\"role\": \"user\", \"content\": instruction}]\n\n        comparison = Comparison(\n            prompt_conversation=prompt_conversation,\n            completion_A=[{\"role\": \"assistant\", \"content\": chosen_response}],\n            completion_B=[{\"role\": \"assistant\", \"content\": rejected_response}],\n        )\n        return LabeledComparison(comparison=comparison, label=\"A\")\n\n\n@chz.chz\nclass ArenaComparisonBuilder(ComparisonDatasetBuilder):\n    \"\"\"Arena dataset comparison builder.\"\"\"\n\n    def get_train_and_test_datasets(self) -> tuple[datasets.Dataset, datasets.Dataset | None]:\n        dataset = datasets.load_dataset(\"lmarena-ai/arena-human-preference-140k\", split=\"train\")\n        dataset = cast(datasets.Dataset, dataset)\n\n        dataset = dataset.shuffle(seed=0)\n        test_dataset = dataset.take(1024)\n        train_dataset = dataset.skip(1024)\n        return train_dataset, test_dataset\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        winner = example[\"winner\"]\n\n        # Skip ties or invalid winners\n        if winner not in [\"model_a\", \"model_b\"]:\n            # print(f\"Skipping arena example with invalid winner: {winner}\")\n            return None\n\n        conversation_a = _arena_parse_conversation(example[\"conversation_a\"])\n        conversation_b = _arena_parse_conversation(example[\"conversation_b\"])\n\n        # Skip if conversations are empty or malformed\n        if not conversation_a or not conversation_b:\n            logger.info(\"Skipping arena example with empty conversations\")\n            return None\n\n        # The conversations should have same prompt (all messages except last assistant response)\n        # Check that both have at least a user message and assistant response\n        if len(conversation_a) < 2 or len(conversation_b) < 2:\n            logger.info(\"Skipping arena example with too few messages\")\n            return None\n\n        # Verify last message is assistant in both\n        if conversation_a[-1][\"role\"] != \"assistant\" or conversation_b[-1][\"role\"] != \"assistant\":\n            logger.info(\"Skipping arena example with non-assistant last message\")\n            return None\n\n        comparison = Comparison(\n            prompt_conversation=conversation_a[0:1],\n            completion_A=conversation_a[1:],\n            completion_B=conversation_b[1:],\n        )\n\n        return LabeledComparison(comparison=comparison, label=\"A\" if winner == \"model_a\" else \"B\")\n\n\n@chz.chz\nclass HelpSteer2ComparisonBuilder(ComparisonDatasetBuilder):\n    \"\"\"HelpSteer2 dataset comparison builder.\"\"\"\n\n    def get_train_and_test_datasets(self) -> tuple[datasets.Dataset, datasets.Dataset | 
None]:\n        dataset = datasets.load_dataset(\"nvidia/HelpSteer2\", split=\"train\")\n        dataset = cast(datasets.Dataset, dataset)\n\n        # Create pairs by grouping examples with same prompt\n        prompt_to_responses: dict[str, list] = {}\n        for i in range(len(dataset)):\n            example = dataset[i]\n            prompt = example[\"prompt\"]\n            if prompt not in prompt_to_responses:\n                prompt_to_responses[prompt] = []\n            prompt_to_responses[prompt].append(example)\n\n        # Create comparison pairs from examples with same prompt\n        comparisons = []\n        for prompt, responses in prompt_to_responses.items():\n            if len(responses) >= 2:\n                # Sort by helpfulness score to create preferences\n                responses.sort(key=lambda x: x[\"helpfulness\"], reverse=True)\n                # Take best vs worst if significant difference\n                if responses[0][\"helpfulness\"] > responses[-1][\"helpfulness\"]:\n                    comparisons.append(\n                        {\n                            \"prompt\": prompt,\n                            \"chosen_response\": responses[0][\"response\"],\n                            \"rejected_response\": responses[-1][\"response\"],\n                            \"helpfulness_diff\": responses[0][\"helpfulness\"]\n                            - responses[-1][\"helpfulness\"],\n                        }\n                    )\n\n        # Convert to dataset\n        df = pd.DataFrame(comparisons)\n        dataset = datasets.Dataset.from_pandas(df)\n        dataset = dataset.shuffle(seed=0)\n\n        test_dataset = dataset.take(min(1024, len(dataset) // 10))\n        train_dataset = dataset.skip(min(1024, len(dataset) // 10))\n        return train_dataset, test_dataset\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        prompt = example[\"prompt\"]\n        chosen_response = example[\"chosen_response\"]\n        rejected_response = example[\"rejected_response\"]\n\n        prompt_conversation: list[renderers.Message] = [{\"role\": \"user\", \"content\": prompt}]\n\n        comparison = Comparison(\n            prompt_conversation=prompt_conversation,\n            completion_A=[{\"role\": \"assistant\", \"content\": chosen_response}],\n            completion_B=[{\"role\": \"assistant\", \"content\": rejected_response}],\n        )\n        return LabeledComparison(comparison=comparison, label=\"A\")\n"
  },
  {
    "path": "tinker_cookbook/recipes/preference/dpo/README.md",
    "content": "# Direct Preference Optimization\n\nPlease check our [doc](https://tinker-docs.thinkingmachines.ai/preferences/dpo-guide) for background on DPO.\n\nHere is an example command:\n```\npython -m tinker_cookbook.recipes.preference.dpo.train \\\n    log_path=/tmp/dpo-hhh-experiment \\\n    model_name=meta-llama/Llama-3.2-1B \\\n    dataset=hhh \\\n    renderer_name=role_colon \\\n    learning_rate=1e-5 \\\n    dpo_beta=0.1\n```\n\nAfter 50 steps, you should expect training metrics like:\n```\n                   Step 50\n┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓\n┃ Metric                         ┃ Value     ┃\n┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩\n│ accuracy                       │ 0.568627  │\n│ batch_time                     │ 27.953704 │\n│ chosen_reward                  │ 0.053621  │\n│ dpo_loss                       │ 0.683825  │\n│ learning_rate                  │ 0.000009  │\n│ margin                         │ 0.002147  │\n│ num_pairs                      │ 255       │\n│ num_tokens                     │ 112638    │\n│ progress                       │ 0.081210  │\n│ rejected_reward                │ 0.032152  │\n│ test/nll                       │ 1.871778  │\n└────────────────────────────────┴───────────┘\n```\n"
  },
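The DPO README above reports `dpo_loss`, `chosen_reward`, `rejected_reward`, `margin`, and `accuracy` without defining them. As a reference, here is a minimal sketch of how these quantities are conventionally computed in DPO from per-pair summed log-probabilities under the policy and a frozen reference model; the actual implementation lives in `tinker_cookbook.preference.train_dpo` and may differ in detail.

```python
import torch
import torch.nn.functional as F


def dpo_pair_metrics(
    policy_logp_chosen: torch.Tensor,    # summed log-prob of the chosen completion under the policy
    policy_logp_rejected: torch.Tensor,  # summed log-prob of the rejected completion under the policy
    ref_logp_chosen: torch.Tensor,       # same quantities under the frozen reference model
    ref_logp_rejected: torch.Tensor,
    beta: float = 0.1,                   # corresponds to the dpo_beta CLI flag
) -> dict[str, torch.Tensor]:
    # Implicit rewards are beta-scaled log-ratios against the reference model.
    chosen_reward = beta * (policy_logp_chosen - ref_logp_chosen)
    rejected_reward = beta * (policy_logp_rejected - ref_logp_rejected)
    margin = chosen_reward - rejected_reward
    # Standard DPO objective: -log sigmoid of the reward margin.
    loss = -F.logsigmoid(margin)
    return {
        "dpo_loss": loss.mean(),
        "chosen_reward": chosen_reward.mean(),
        "rejected_reward": rejected_reward.mean(),
        "margin": margin.mean(),
        "accuracy": (margin > 0).float().mean(),
    }
```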
  {
    "path": "tinker_cookbook/recipes/preference/dpo/train.py",
    "content": "\"\"\"\nBasic CLI for training with Direct Preference Optimization (DPO). It only supports a few datasets and configuration options; if you want to do something more complicated, please write a new script and call the train_dpo.main function directly.\n\"\"\"\n\nfrom datetime import datetime\n\nimport chz\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.preference import train_dpo\nfrom tinker_cookbook.preference.dpo_datasets import (\n    DPODatasetBuilderFromComparisons,\n)\nfrom tinker_cookbook.recipes.preference.datasets import (\n    HelpSteer3ComparisonBuilder,\n    HHHComparisonBuilder,\n    UltraFeedbackComparisonBuilder,\n)\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilder, ChatDatasetBuilderCommonConfig\nfrom tinker_cookbook.utils.lr_scheduling import LRSchedule\n\n\n@chz.chz\nclass CLIConfig:\n    model_name: str = \"meta-llama/Llama-3.2-1B\"\n    dataset: str = \"hhh\"  # or path like tinker_cookbook.preference.preference_datasets:HHHBuilder\n    load_checkpoint_path: str | None = None\n    renderer_name: str | None = None\n\n    # Training parameters\n    learning_rate: float = 1e-5\n    lr_schedule: LRSchedule = \"linear\"\n    dpo_beta: float = 0.1\n    max_length: int | None = 8192\n    batch_size: int = 256\n\n    # Logging parameters\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    # Service configuration\n    base_url: str | None = None\n\n    # DPO-specific parameters\n    reference_model_name: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\ndef get_dataset_builder(\n    dataset: str,\n    model_name: str,\n    renderer_name: str,\n    max_length: int | None,\n    batch_size: int,\n) -> ChatDatasetBuilder:\n    \"\"\"Get the appropriate dataset builder for DPO training.\"\"\"\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=model_name,\n        renderer_name=renderer_name,\n        max_length=max_length,\n        batch_size=batch_size,\n    )\n\n    if dataset == \"hhh\":\n        return DPODatasetBuilderFromComparisons(\n            common_config=common_config, comparison_builder=HHHComparisonBuilder()\n        )\n    elif dataset == \"helpsteer3\":\n        return DPODatasetBuilderFromComparisons(\n            common_config=common_config, comparison_builder=HelpSteer3ComparisonBuilder()\n        )\n    elif dataset == \"ultrafeedback\":\n        return DPODatasetBuilderFromComparisons(\n            common_config=common_config, comparison_builder=UltraFeedbackComparisonBuilder()\n        )\n    else:\n        raise ValueError(f\"Unknown dataset: {dataset}\")\n\n\ndef cli_main(cli_config: CLIConfig):\n    \"\"\"Main CLI function that builds the full config and calls the training function.\"\"\"\n    # Build full config\n    renderer_name = checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    model_name = cli_config.model_name.replace(\"/\", \"-\")\n    run_name = f\"{cli_config.dataset}-{model_name}-{cli_config.learning_rate}lr-{cli_config.batch_size}batch-{date_and_time}\"\n    if cli_config.log_path is not None:\n        log_path = 
cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/dpo/{run_name}\"\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    config = train_dpo.Config(\n        log_path=log_path,\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        dataset_builder=get_dataset_builder(\n            cli_config.dataset,\n            cli_config.model_name,\n            renderer_name,\n            cli_config.max_length,\n            cli_config.batch_size,\n        ),\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        evaluator_builders=[],\n        learning_rate=cli_config.learning_rate,\n        lr_schedule=cli_config.lr_schedule,\n        dpo_beta=cli_config.dpo_beta,\n        base_url=cli_config.base_url,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        reference_model_name=cli_config.reference_model_name,\n        max_steps=cli_config.max_steps,\n    )\n\n    train_dpo.main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    cli_main(cli_config)\n"
  },
  {
    "path": "tinker_cookbook/recipes/preference/rlhf/README.md",
    "content": "# RLHF Pipeline\n\n```bash\npython -m tinker_cookbook.recipes.preference.rlhf.rlhf_pipeline\n```\n\nThere are three stages:\n1. Policy SFT stage: this stage is short, and `test/nll` should decrease from 1.99 to 1.92 in 20 steps.\n2. Reward model SFT stage: this stage is longer, and `test/nll` should drastically decrease from 7 to around 0.7 in the first 40 steps, slowly decrease to 0.6 at around step 300, and converge to around 0.55 in 600 steps. This stage needs to finish before the next stage.\n3. Policy RL stage: `test/win_rate` should increase from ~40% to ~70% in 100 steps.\n\n### Stage 1 and 2: Supervised Fine-Tuning\n\nThe first two stages are supervised fine-tuning, which is relatively straightforward (see `recipes.sl_basic` for an example). In the first stage, we perform supervised fine-tuning to initialize the policy on the `no_robots` dataset from Hugging Face; in the second stage, we perform supervised fine-tuning to learn the reward model on the `HHH` dataset from Anthropic.\n\n### Stage 3: RL against the Reward Model\n\nIn the third stage, we initialize with the policy produced by the first stage, and optimize against the reward model learned in the second stage.\nAs before, we need to implement a `PreferenceModelBuilder` and a `ComparisonBuilder`.\nIn our implementation, we use `PreferenceModelBuilderFromChatRenderer` for the former, and `HHHComparisonBuilder` for the latter.\nNow we can optimize against a learned reward model!\n\n### Next\n\nWe include another way to learn from preferences, DPO, which requires a custom loss function.\n\n1. Rajani, N., Tunstall, L., Beeching, E., Lambert, N., Rush, A. M., & Wolf, T. (2023). No Robots. [https://huggingface.co/datasets/HuggingFaceH4/no_robots](https://huggingface.co/datasets/HuggingFaceH4/no_robots)\n2. Bai, Y., Kadavath, S., Kundu, S., Askell, A., Kernion, J., Jones, A., Chen, A., Goldie, A., Mirhoseini, A., McKinnon, C., Chen, C., Olsson, C., Olah, C., Hernandez, D., Drain, D., Ganguli, D., Li, D., Tran-Johnson, E., Perez, E., Kerr, J., Mueller, J., Ladish, J., Landau, J., Ndousse, K., Lukošiūtė, K., Lovitt, L., Sellitto, M., Elhage, N., Schiefer, N., Mercado, N., ... Kaplan, J. (2022). Training a helpful and harmless assistant with reinforcement learning from human feedback. arXiv. https://arxiv.org/abs/2204.05862\n"
  },
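Stage 3 never sees scalar rewards directly: the learned preference model only scores pairs of completions. Purely as an illustration (the real logic is in `tinker_cookbook.rl.preference_envs`, configured below with `TournamentPattern.ALL_PAIRS_BOTH_WAYS`), here is one plausible way to collapse all-pairs, both-ways preference scores into a per-rollout reward, assuming a `score(a, b)` function where a larger value means `b` is preferred:

```python
from itertools import permutations
from typing import Callable, Sequence


def all_pairs_both_ways_rewards(
    completions: Sequence[str],
    score: Callable[[str, str], float],  # score(a, b) > 0 means b beats a
) -> list[float]:
    """Illustrative only: each rollout's reward is its average signed margin
    against every other rollout, evaluated in both orderings."""
    totals = [0.0] * len(completions)
    counts = [0] * len(completions)
    for i, j in permutations(range(len(completions)), 2):
        s = score(completions[i], completions[j])
        totals[j] += s      # credit the second completion when it is preferred
        totals[i] -= s      # debit the first by the same margin
        counts[i] += 1
        counts[j] += 1
    return [t / c if c else 0.0 for t, c in zip(totals, counts)]
```

How the library actually clamps, averages, or centers these pairwise scores is an implementation detail of `preference_envs`, not something this sketch pins down.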
  {
    "path": "tinker_cookbook/recipes/preference/rlhf/rlhf_pipeline.py",
    "content": "import asyncio\nimport logging\nfrom pathlib import Path\n\nimport chz\n\nfrom tinker_cookbook import checkpoint_utils, model_info\nfrom tinker_cookbook.preference.comparison_policy_evaluator import ComparisonEvaluator\nfrom tinker_cookbook.preference.preference_datasets import ChatDatasetBuilderFromComparisons\nfrom tinker_cookbook.preference.types import PreferenceModelBuilderFromChatRenderer\nfrom tinker_cookbook.recipes.chat_sl.chat_datasets import NoRobotsBuilder\nfrom tinker_cookbook.recipes.preference.datasets import HHHComparisonBuilder\nfrom tinker_cookbook.renderers import TrainOnWhat\nfrom tinker_cookbook.rl import preference_envs, train\nfrom tinker_cookbook.supervised import train as supervised_train\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilderCommonConfig\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass CLIConfig:\n    base_model: str = \"meta-llama/Llama-3.2-3B\"\n    short_name: str = \"llama3b\"\n    run_sft: bool = True\n    run_rm: bool = True\n    run_rl: bool = True\n    wandb_project: str | None = None\n    wandb_name: str | None = \"rlhf\"\n    lora_rank: int = 64\n    max_length: int = 16384\n    batch_size: int = 256\n\n    sft_learning_rate: float = 2e-4\n    rm_learning_rate: float = 3e-4\n    rl_learning_rate: float = 1e-5\n    rl_max_tokens: int = 1024\n    rl_group_size: int = 4\n\n    save_every: int = 100\n    eval_every: int = 20\n\n    # Logtree configuration - number of groups to log per iteration (0 = disable)\n    num_groups_to_log: int = 4\n\n    max_steps: int | None = None\n\n\ndef sft_stage(\n    log_path: str,\n    base_model: str,\n    wandb_project: str | None,\n    wandb_name: str | None,\n    lora_rank: int,\n    batch_size: int,\n    learning_rate: float,\n    max_length: int,\n    save_every: int,\n    eval_every: int,\n    max_steps: int | None = None,\n):\n    \"\"\"\n    Train base policy on NoRobots dataset\n    \"\"\"\n    # Create renderer for the model\n    renderer_name = model_info.get_recommended_renderer_name(base_model)\n\n    # Create common config for the dataset builder\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=base_model,\n        renderer_name=renderer_name,\n        max_length=max_length,\n        batch_size=batch_size,\n        train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES,\n    )\n\n    # Use NoRobots dataset for SFT\n    dataset_builder = NoRobotsBuilder(common_config=common_config)\n\n    # Create training config\n    config = supervised_train.Config(\n        log_path=log_path,\n        model_name=base_model,\n        renderer_name=renderer_name,\n        dataset_builder=dataset_builder,\n        evaluator_builders=[],  # Could add evaluators here\n        num_epochs=1,\n        learning_rate=learning_rate,\n        lr_schedule=\"linear\",\n        save_every=save_every,\n        eval_every=eval_every,\n        lora_rank=lora_rank,\n        wandb_project=wandb_project,\n        wandb_name=f\"{wandb_name}-sft\",\n        max_steps=max_steps,\n    )\n\n    # Run training\n    asyncio.run(supervised_train.main(config))\n\n\ndef train_rm(\n    log_path: str,\n    base_model: str,\n    wandb_project: str | None,\n    wandb_name: str | None,\n    lora_rank: int,\n    batch_size: int,\n    learning_rate: float,\n    max_length: int,\n    save_every: int,\n    eval_every: int,\n    max_steps: int | None = None,\n):\n    \"\"\"Train reward model using Anthropic HHH preference comparisons.\"\"\"\n    # Use HHH comparison builder for 
Anthropic data\n    comparison_builder = HHHComparisonBuilder()\n\n    # Get renderer name for the model\n    renderer_name = model_info.get_recommended_renderer_name(base_model)\n\n    # Create common config for the dataset builder\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=base_model,\n        renderer_name=renderer_name,\n        max_length=max_length,\n        batch_size=batch_size,\n    )\n\n    # Create the dataset builder that wraps comparisons with rendering\n    dataset_builder = ChatDatasetBuilderFromComparisons(\n        common_config=common_config, comparison_builder=comparison_builder\n    )\n\n    # Create training config\n    config = supervised_train.Config(\n        log_path=log_path,\n        model_name=base_model,\n        renderer_name=renderer_name,\n        dataset_builder=dataset_builder,\n        evaluator_builders=[],  # Could add evaluators here\n        num_epochs=1,\n        learning_rate=learning_rate,\n        lr_schedule=\"linear\",\n        save_every=save_every,\n        eval_every=eval_every,\n        lora_rank=lora_rank,\n        wandb_project=wandb_project,\n        wandb_name=f\"{wandb_name}-rm\",\n        max_steps=max_steps,\n    )\n\n    # Run training\n    asyncio.run(supervised_train.main(config))\n\n\nasync def train_rl(\n    log_path: str,\n    sft_log_path: str,\n    rm_log_path: str,\n    base_model: str,\n    wandb_project: str | None,\n    wandb_name: str | None,\n    lora_rank: int,\n    group_size: int,\n    batch_size: int,\n    learning_rate: float,\n    max_tokens: int,\n    save_every: int,\n    eval_every: int,\n    num_groups_to_log: int = 4,\n    max_steps: int | None = None,\n):\n    \"\"\"Train policy using RL with prompts from Anthropic HHH data.\"\"\"\n    # Get checkpoints from previous stages\n    sft_checkpoint_dict = checkpoint_utils.get_last_checkpoint(sft_log_path)\n    rm_checkpoint_dict = checkpoint_utils.get_last_checkpoint(rm_log_path)\n\n    if sft_checkpoint_dict is None:\n        raise ValueError(f\"No SFT checkpoint found in {sft_log_path}\")\n    if rm_checkpoint_dict is None:\n        raise ValueError(f\"No RM checkpoint found in {rm_log_path}\")\n\n    sft_checkpoint = sft_checkpoint_dict.state_path\n    rm_weights_path = rm_checkpoint_dict.sampler_path\n\n    # Use HHH comparison builder for prompts\n    comparison_builder = HHHComparisonBuilder()\n    renderer_name = model_info.get_recommended_renderer_name(base_model)\n\n    preference_model_builder = PreferenceModelBuilderFromChatRenderer(\n        renderer_name=renderer_name,\n        model_name=base_model,\n        rm_weights_path=rm_weights_path,\n    )\n\n    rl_dataset_builder = preference_envs.PairwisePreferenceRLDatasetBuilder(\n        comparison_builder=comparison_builder,\n        policy_renderer_name=renderer_name,\n        policy_model_name=base_model,\n        preference_model_builder=preference_model_builder,\n        batch_size=batch_size,\n        group_size=group_size,\n        tournament_pattern=preference_envs.TournamentPattern.ALL_PAIRS_BOTH_WAYS,\n    )\n\n    def get_evaluator_builder() -> ComparisonEvaluator:\n        comparison_builder_eval = HHHComparisonBuilder(test_size=256)\n        _, test_dataset = comparison_builder_eval.get_train_and_test_datasets()\n        assert test_dataset is not None\n        test_labeled_comparisons = [\n            comparison_builder_eval.example_to_labeled_comparison(example)  # type: ignore\n            for example in test_dataset\n        ]\n        
test_comparisons = [lc.comparison for lc in test_labeled_comparisons if lc is not None]\n        return ComparisonEvaluator(\n            preference_model_builder=preference_model_builder,\n            comparisons=test_comparisons,\n            renderer_name=renderer_name,\n            model_name_for_tokenizer=base_model,\n        )\n\n    config = train.Config(\n        model_name=base_model,\n        renderer_name=renderer_name,\n        dataset_builder=rl_dataset_builder,\n        load_checkpoint_path=sft_checkpoint,\n        learning_rate=learning_rate,\n        max_tokens=max_tokens,\n        log_path=log_path,\n        evaluator_builders=[get_evaluator_builder],\n        wandb_project=wandb_project,\n        wandb_name=f\"{wandb_name}-rl\",\n        lora_rank=lora_rank,\n        save_every=save_every,\n        eval_every=eval_every,\n        num_groups_to_log=num_groups_to_log,\n        max_steps=max_steps,\n    )\n    await train.main(config)\n\n\ndef cli_main(cli_config: CLIConfig):\n    log_path_root = Path(f\"/tmp/tinker-examples/rlhf-{cli_config.short_name}\")\n    sft_log_path = str(log_path_root / \"sft\")\n    rm_log_path = str(log_path_root / \"rm\")\n    rl_log_path = str(log_path_root / \"rl\")\n    if cli_config.run_sft:\n        sft_stage(\n            sft_log_path,\n            cli_config.base_model,\n            cli_config.wandb_project,\n            cli_config.wandb_name,\n            cli_config.lora_rank,\n            cli_config.batch_size,\n            cli_config.sft_learning_rate,\n            cli_config.max_length,\n            cli_config.save_every,\n            cli_config.eval_every,\n            max_steps=cli_config.max_steps,\n        )\n    if cli_config.run_rm:\n        train_rm(\n            rm_log_path,\n            cli_config.base_model,\n            cli_config.wandb_project,\n            cli_config.wandb_name,\n            cli_config.lora_rank,\n            cli_config.batch_size,\n            cli_config.rm_learning_rate,\n            cli_config.max_length,\n            cli_config.save_every,\n            cli_config.eval_every,\n            max_steps=cli_config.max_steps,\n        )\n    if cli_config.run_rl:\n        asyncio.run(\n            train_rl(\n                rl_log_path,\n                sft_log_path,\n                rm_log_path,\n                cli_config.base_model,\n                cli_config.wandb_project,\n                cli_config.wandb_name,\n                cli_config.lora_rank,\n                cli_config.rl_group_size,\n                cli_config.batch_size,\n                cli_config.rl_learning_rate,\n                cli_config.rl_max_tokens,\n                cli_config.save_every,\n                cli_config.eval_every,\n                cli_config.num_groups_to_log,\n                max_steps=cli_config.max_steps,\n            )\n        )\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    cli_main(cli_config)\n"
  },
  {
    "path": "tinker_cookbook/recipes/preference/shorter/README.md",
    "content": "# Generating Shorter Responses via Comparisons\n\n```bash\npython -m tinker_cookbook.recipes.preference.shorter.train\n```\n\n`ac_tokens_per_turn` should drop significantly after 40 steps. The policy generates significantly shorter responses.\n\n### Using the `PairwisePreferenceRLDatasetBuilder` class\n\nWe implement the `PairwisePreferenceRLDatasetBuilder` class to make it easier to learn from preference pairs, rather than scalar rewards for an individual trajectory. The key objects you need to implement are: (1) PreferenceModelBuilder, and (2) ComparisonBuilder.\n\n**PreferenceModelBuilder** will build a *PreferenceModel* when called (via its `__call__()` method), which determines what responses are preferred. Concretely, `PreferenceModel.__call__`\n- accepts a `Comparison` object, which contains (1) `prompt_conversation`: a list of input messages that the policy model receives, and (2) `completion_A` and `completion_B`, each a list of messages that the policy model generates.\n- returns a floating point number; a larger number means that `completion_B` is better.\n\n**ComparisonBuilder** will be used in our code to create a list of `Comparison` objects. We need to implement two functions\n- `get_train_and_test_datasets`: which returns training and test Hugging Face `Dataset` objects\n- `example_to_labeled_comparison`: which converts each datapoint (a `dict` in the `Dataset` object) to a `Comparison` object.\n\nNote that `completion_A` and `completion_B` will NOT be used during training, and only `completion_A` will be used during `Eval`.\n\n### Implementation of This Simple Example\n\nWe build a simple preference model that prefers shorter responses in `shorter.env.PreferenceModelShorter`. Then we implement a simple wrapper that builds it in `ShorterPreferenceModelBuilder`. To implement the `ComparisonBuilder` object, we create dummy datasets, where the `input_messages` field is a constant `Who are you?`.\n\n### Next:\n\nWe will go through an example of the RLHF pipeline, which heavily relies on the `PairwisePreferenceRLDatasetBuilder` and `ComparisonBuilder` abstraction.\n"
  },
  {
    "path": "tinker_cookbook/recipes/preference/shorter/env.py",
    "content": "import chz\nfrom datasets import Dataset\n\nfrom tinker_cookbook.preference.preference_datasets import ComparisonDatasetBuilder\nfrom tinker_cookbook.preference.types import (\n    Comparison,\n    LabeledComparison,\n    PreferenceModel,\n    PreferenceModelBuilder,\n)\nfrom tinker_cookbook.renderers import Message\n\nCONVO_PREFIX: list[Message] = [{\"role\": \"user\", \"content\": \"Who are you?\"}]\nDUMMY_COMPLETION: list[Message] = [\n    {\n        \"role\": \"assistant\",\n        \"content\": \"Hello thre! I am a large language model. How can I assist you? Feel free to ask me anything.\",\n    }\n]\nDUMMY_COMPARISON: Comparison = Comparison(\n    prompt_conversation=CONVO_PREFIX,\n    completion_A=DUMMY_COMPLETION,\n    completion_B=DUMMY_COMPLETION,\n)\nDUMMY_DATASET: Dataset = Dataset.from_list([{\"id\": None}] * 1024)\n\n\nclass PreferenceModelShorter(PreferenceModel):\n    \"\"\"\n    A dummy preference model that always prefers a shorter response\n    \"\"\"\n\n    def _get_completion_length(self, completion: list[Message]) -> int:\n        char_count = 0\n        for message in completion:\n            char_count += len(message[\"content\"])\n        return char_count\n\n    async def __call__(self, comparison: Comparison) -> float:\n        length_a = self._get_completion_length(comparison.completion_A)\n        length_b = self._get_completion_length(comparison.completion_B)\n        if length_a > length_b:\n            return 1.0\n        elif length_b > length_a:\n            return -1.0\n        else:\n            return 0.0\n\n\n@chz.chz\nclass ShorterComparisonBuilder(ComparisonDatasetBuilder):\n    def get_train_and_test_datasets(self) -> tuple[Dataset, Dataset | None]:\n        return DUMMY_DATASET, None\n\n    def example_to_labeled_comparison(self, example: dict) -> LabeledComparison | None:\n        return LabeledComparison(comparison=DUMMY_COMPARISON, label=\"Tie\")\n\n\n@chz.chz\nclass ShorterPreferenceModelBuilder(PreferenceModelBuilder):\n    def __call__(self) -> PreferenceModel:\n        return PreferenceModelShorter()\n"
  },
  {
    "path": "tinker_cookbook/recipes/preference/shorter/train.py",
    "content": "import asyncio\nfrom datetime import datetime\n\nimport chz\n\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.recipes.preference.shorter.env import (\n    ShorterComparisonBuilder,\n    ShorterPreferenceModelBuilder,\n)\nfrom tinker_cookbook.rl import train\nfrom tinker_cookbook.rl.preference_envs import PairwisePreferenceRLDatasetBuilder\n\n\n@chz.chz\nclass CLIConfig:\n    model_name: str = \"Qwen/Qwen3-4B-Instruct-2507\"\n    renderer_name: str | None = None\n\n    # Training parameters\n    batch_size: int = 32\n    group_size: int = 16\n    learning_rate: float = 3e-5\n    max_tokens: int = 64\n    eval_every: int = 5\n\n    # Logging parameters\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\ndef cli_main(cli_config: CLIConfig):\n    model_name = cli_config.model_name\n    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(model_name)\n\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    model_tag = model_name.replace(\"/\", \"-\")\n    run_name = f\"shorter-{model_tag}-{cli_config.batch_size}batch-{cli_config.group_size}group-{cli_config.learning_rate}lr-{date_and_time}\"\n\n    log_path = cli_config.log_path or f\"/tmp/tinker-examples/shorter/{run_name}\"\n    wandb_name = cli_config.wandb_name or run_name\n\n    comparison_builder = ShorterComparisonBuilder()\n    dataset_builder = PairwisePreferenceRLDatasetBuilder(\n        comparison_builder=comparison_builder,\n        batch_size=cli_config.batch_size,\n        policy_renderer_name=renderer_name,\n        policy_model_name=model_name,\n        group_size=cli_config.group_size,\n        preference_model_builder=ShorterPreferenceModelBuilder(),\n    )\n\n    config = train.Config(\n        model_name=model_name,\n        renderer_name=renderer_name,\n        log_path=log_path,\n        dataset_builder=dataset_builder,\n        learning_rate=cli_config.learning_rate,\n        max_tokens=cli_config.max_tokens,\n        eval_every=cli_config.eval_every,\n        compute_post_kl=True,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        max_steps=cli_config.max_steps,\n    )\n\n    # Avoid clobbering log dir from your previous run:\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n    asyncio.run(train.main(config))\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    cli_main(cli_config)\n"
  },
  {
    "path": "tinker_cookbook/recipes/prompt_distillation/README.md",
    "content": "# Prompt Distillation\n\nPrompt Distillation -- also known as context distillation [1,2] -- is a training method that can \"make an LLM internalize the prompt into its parameters\".\nIn this method, the model is fine-tuned to behave as if it had been provided with a long and complex prompt, even without actually accessing it.\n\nFor example, we want to internalize the following target prompt $p$:\n\n`Classify the language of the provided text into these labels: en, fr, zh, ja ...`\n\nAfter prompt distillation, the LLM will respond with only the language label after receiving a query without seeing the prompt $p$, e.g.,\n```\nQuery: 一生、バンドしてくれる？\nResponse: ja\n```\n\nAt a high level, this method involves two stages:\n1. **Creating data for distillation**: A teacher language model uses $p$ to generate responses $r$ on a set of queries $q$; i.e. $r \\sim \\text{teacher}(\\cdot|p, q)$\n2. **Training the student model**: A student model is fine-tuned to predict the responses $r$ to the query $q$ but without accessing $p$, hence learning to behave as if the target prompt is in its context; i.e. $\\text{student}(\\cdot | q)$ should predict $r$\n\n## Example\n\nThe Tinker Cookbook provides a prompt distillation recipe tailored for a language classification task. The objective is straightforward: given a text query, the model should predict a two-character code corresponding to the language of the input. The set of possible labels is:\n```\nar (Arabic), de (German), el (Greek), en (English), es (Spanish), fr (French), hi (Hindi), ru (Russian), tr (Turkish), ur (Urdu), vi (Vietnamese), zh (Chinese - Simplified), ot (Other/Unknown).\n```\n\nThe recipe in [`create_data.py`](create_data.py) also includes handling strategies for inputs containing code, numerical content, or multiple languages.\n\nIn the example below, the same model (`Qwen/Qwen3-30B-A3B`) is used as both teacher and student, though in general they need not be identical.\n\n---\n\n### Step 1: Generate Training Data\n\nGenerate prompt distillation data using the teacher model with [`create_data.py`](create_data.py):\n\n```bash\nmkdir -p /tmp/tinker-datasets\npython -m tinker_cookbook.recipes.prompt_distillation.create_data \\\n  output_file=/tmp/tinker-datasets/prompt_distillation_lang.jsonl\n```\n\nThis command will:\n- Use the configured teacher model to generate language classification examples\n- Save the distilled dataset to the specified output file\n- Create diverse training examples suitable for student model fine-tuning\n\n### Step 2: Train the Student Model\n\nFine-tune a student model on the distillation data using [`train.py`](train.py):\n\n```bash\npython -m tinker_cookbook.recipes.prompt_distillation.train\n```\n\nThe training script will:\n- Load the generated distillation dataset\n- Apply optimized training configurations\n- Fine-tune the student model for language classification\n\n### Step 3: Test Your Model\n\nOnce training is complete, you can test your distilled model by sampling from the trained model to verify its performance on language classification tasks.\n\n## Advanced Configuration\n\nThe prompt distillation recipe can be customized for different scenarios:\n\n- **Teacher model selection**: Choose different base models based on your requirements\n- **Sampling strategies**: Adjust temperature and other generation parameters\n- **Data volume**: Scale the number of generated examples based on your needs\n- **Training hyperparameters**: Fine-tune learning rates and other training settings\n\n[1] 
Askell, A., Bai, Y., Chen, A., Drain, D., Ganguli, D., Henighan, T., Jones, A., Joseph, N., Mann, B., DasSarma, N., Elhage, N., Hatfield-Dodds, Z., Hernandez, D., Kernion, J., Ndousse, K., Olsson, C., Amodei, D., Brown, T., Clark, J., McCandlish, S., Olah, C., & Kaplan, J. (2021). A general language assistant as a laboratory for alignment. arXiv preprint arXiv:2112.00861.\n\n[2] Snell, C., Klein, D., & Zhong, R. (2022). Learning by distilling context. arXiv preprint arXiv:2209.15189.\n"
  },
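The two stages described above (the teacher generates r ~ teacher(·|p, q), the student is fine-tuned on q → r alone) are implemented by `create_data.py` and `train.py` below. As a schematic only, with `teacher_generate` and `student_finetune` as hypothetical stand-ins rather than real APIs, the whole method reduces to:

```python
# Schematic only: teacher_generate and student_finetune are hypothetical helpers,
# not functions from the cookbook or the Tinker SDK.


def distill_prompt(target_prompt: str, queries: list[str], teacher_generate, student_finetune):
    # Stage 1: the teacher sees the full target prompt p and each query q,
    # producing responses r ~ teacher(. | p, q).
    pairs = [(q, teacher_generate(prompt=target_prompt, query=q)) for q in queries]
    # Stage 2: the student is fine-tuned on (q, r) pairs with p omitted,
    # so it learns to answer as if p were in its context.
    return student_finetune([{"user": q, "assistant": r} for q, r in pairs])
```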
  {
    "path": "tinker_cookbook/recipes/prompt_distillation/create_data.py",
    "content": "import asyncio\nimport json\nimport re\nfrom pathlib import Path\nfrom typing import Any\n\nimport chz\nimport tinker\nfrom tqdm.asyncio import tqdm_asyncio\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nLANGUAGE_CLASSIFICATION_PROMPT = \"\"\"You are a precise language classifier.\n\nGoal: Classify the language of the provided text into exactly one of these labels:\nar (Arabic), de (German), el (Greek), en (English), es (Spanish), fr (French),\nhi (Hindi), ru (Russian), tr (Turkish), ur (Urdu), vi (Vietnamese),\nzh (Chinese - Simplified), ot (Other/Unknown).\n\nInstructions:\n1) Preprocess carefully (without changing the intended meaning):\n   - Trim whitespace.\n   - Ignore URLs, emails, file paths, hashtags, user handles, and emojis.\n   - Ignore numbers, math expressions, and standalone punctuation.\n   - If there is code, IGNORE code syntax (keywords, operators, braces) and focus ONLY on human language in comments and string literals.\n   - Preserve letters and diacritics; do NOT strip accents.\n   - If after ignoring the above there are no alphabetic letters left, output 'ot'.\n\n2) Script-based rules (highest priority):\n   - Devanagari script → hi.\n   - Greek script → el.\n   - Cyrillic script → ru.\n   - Han characters (中文) → zh. (Treat Traditional as zh too.)\n   - Arabic script → ar vs ur:\n       • If Urdu-only letters appear (e.g., ے, ڑ, ں, ھ, ٹ, ڈ, کھ, گ, چ with Urdu forms), or clear Urdu words, choose ur.\n       • Otherwise choose ar.\n   (If multiple scripts appear, pick the script that contributes the majority of alphabetic characters. If tied, go to step 5.)\n\n3) Latin-script heuristics (use when text is mainly Latin letters):\n   - vi: presence of Vietnamese-specific letters/diacritics (ă â ê ô ơ ư đ, plus dense diacritics across many words).\n   - tr: presence of Turkish-specific letters (ı İ ğ Ğ ş Ş ç Ç ö Ö ü Ü) and common function words (ve, bir, için, değil, ama, çok).\n   - de: presence of umlauts (ä ö ü) or ß and common function words (und, der, die, das, nicht, ist).\n   - es: presence of ñ, ¿, ¡ and common words (y, de, la, el, es, no, por, para, con, gracias, hola).\n   - fr: frequent French diacritics (é è ê à ç ô â î û ù) and common words (et, le, la, les, des, une, est, avec, pour, merci, bonjour).\n   - en: default among Latin languages if strong evidence for others is absent, but ONLY if English function words are present (the, and, is, are, to, of, in, for, on, with). If evidence is insufficient for any Latin language, prefer 'ot' over guessing.\n\n4) Named entities & loanwords:\n   - Do NOT decide based on a single proper noun, brand, or place name.\n   - Require at least two function words or repeated language-specific signals (diacritics/letters) before assigning a Latin-language label.\n\n5) Mixed-language text:\n   - Determine the dominant language by counting indicative tokens (language-specific letters/diacritics/function words) AFTER preprocessing.\n   - If two or more languages are equally dominant or the text is a deliberate multi-language mix, return 'ot'.\n\n6) Very short or noisy inputs:\n   - If the text is ≤2 meaningful words or too short to be confident, return 'ot' unless there is a very strong language-specific signal (e.g., “bonjour” → fr, “hola” → es).\n\n7) Transliteration/romanization:\n   - If Hindi/Urdu/Arabic/Chinese/Russian/Greek is written purely in Latin letters (romanized) without clear, repeated language-specific cue words, return 'ot'. 
(Only classify as hi/ur/ar/zh/ru/el when native scripts or highly distinctive romanized patterns are clearly present.)\n\n8) Code-heavy inputs:\n   - If the text is mostly code with minimal or no natural-language comments/strings, return 'ot'.\n   - If comments/strings clearly indicate a language per rules above, use that label.\n\n9) Ambiguity & confidence:\n   - When in doubt, choose 'ot' rather than guessing.\n\nOutput format:\n- Respond with EXACTLY one line: \"Final Answer: xx\"\n- Where xx ∈ {{ar, de, el, en, es, fr, hi, ru, tr, ur, vi, zh, ot}} and nothing else.\n\nText to classify:\n{text}\n\"\"\"\n\n\n@chz.chz\nclass Config:\n    output_file: str\n\n\ndef setup_clients():\n    # disable tokenizer parallelism warnings\n    import os\n\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n    print(\"Creating service client\")\n    service_client = tinker.ServiceClient()\n    print(\"Creating sampling client\")\n    sampling_client = service_client.create_sampling_client(base_model=\"Qwen/Qwen3-30B-A3B\")\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = renderers.get_renderer(\"qwen3\", tokenizer)\n\n    return sampling_client, tokenizer, renderer\n\n\nasync def create_data_async(cfg: Config, sampling_client: Any, tokenizer: Any, renderer: Any):\n    # read sentences from multilingual.txt file\n    with open(\"tinker_cookbook/example_data/multilingual.txt\") as f:\n        sentences = f.readlines()\n    sentences = [sentence.strip() for sentence in sentences]\n\n    print(f\"Loaded {len(sentences)} sentences\")\n\n    async def sample_from_model(\n        sentence: str,\n    ) -> tuple[str, str | None]:\n        prompt = LANGUAGE_CLASSIFICATION_PROMPT.format(text=sentence)\n        tokenized_prompt = tinker.ModelInput.from_ints(tokenizer.encode(prompt))\n        params = tinker.SamplingParams(\n            max_tokens=1000, temperature=0.15, stop=renderer.get_stop_sequences()\n        )\n        result = await sampling_client.sample_async(\n            prompt=tokenized_prompt, sampling_params=params, num_samples=1\n        )\n        response = tokenizer.decode(result.sequences[0].tokens)\n        # parse the final answer from the response using regex for example: Final Answer: xx where xx is two character label for each language and nothing else. 
xx is one of the following: ar, de, el, en, es, fr, hi, ru, tr, ur, vi, zh, ot.\n        # the final answer is the xx part\n        search_response = re.search(r\"Final Answer: (\\w+)\", response)\n        final_answer = search_response.group(1) if search_response else None\n        return (sentence, final_answer)\n\n    answers: list[str | None] = []\n    questions: list[str] = []\n    for coro in tqdm_asyncio.as_completed(\n        [sample_from_model(s) for s in sentences], total=len(sentences)\n    ):\n        question, answer = await coro\n        answers.append(answer)\n        questions.append(question)\n\n    # save the input and final answer to a file\n    with open(cfg.output_file, \"w\") as f:\n        for question, answer in zip(questions, answers):\n            if answer is None:\n                continue\n            messages = {\n                \"messages\": [\n                    {\n                        \"role\": \"user\",\n                        \"content\": question,\n                    },\n                    {\n                        \"role\": \"assistant\",\n                        \"content\": answer,\n                    },\n                ],\n            }\n            f.write(json.dumps(messages) + \"\\n\")\n\n    return\n\n\ndef main(cfg: Config):\n    # check if the output file exists\n    output_path = Path(cfg.output_file)\n    if output_path.exists():\n        print(f\"Output file {cfg.output_file} already exists\")\n        return\n    elif not output_path.parent.exists():\n        # check if the output directory exists\n        print(f\"Output directory {output_path.parent} does not exist\")\n        print(f\"Creating directory {output_path.parent}\")\n        output_path.parent.mkdir(parents=True, exist_ok=True)\n\n    # Setup clients synchronously\n    sampling_client, tokenizer, renderer = setup_clients()\n\n    print(\"Sampling data\")\n    # Run async data creation\n    asyncio.run(create_data_async(cfg, sampling_client, tokenizer, renderer))\n    print(f\"Saved data to {cfg.output_file}\")\n\n\nif __name__ == \"__main__\":\n    chz.nested_entrypoint(main)\n"
  },
  {
    "path": "tinker_cookbook/recipes/prompt_distillation/train.py",
    "content": "\"\"\"\nCLI for prompt distillation training.\n\"\"\"\n\nimport asyncio\nfrom datetime import datetime\nfrom pathlib import Path\n\nimport chz\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.renderers import TrainOnWhat\nfrom tinker_cookbook.supervised import train\nfrom tinker_cookbook.supervised.data import FromConversationFileBuilder\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilderCommonConfig\nfrom tinker_cookbook.utils.lr_scheduling import LRSchedule\n\n\n@chz.chz\nclass CLIConfig:\n    # Required parameters\n    file_path: str = \"/tmp/tinker-datasets/prompt_distillation_lang.jsonl\"\n    log_path: str | None = None\n    model_name: str = \"Qwen/Qwen3-30B-A3B\"\n    load_checkpoint_path: str | None = None\n\n    # Training parameters\n    learning_rate: float = 1e-4\n    lr_schedule: LRSchedule = \"linear\"\n    num_epochs: int = 4\n\n    # Model parameters\n    lora_rank: int = 32\n\n    # Infrastructure parameters\n    base_url: str | None = None\n\n    # Checkpointing and evaluation\n    save_every: int = 20\n    eval_every: int = 5\n\n    # Dataset-specific parameters\n    renderer_name: str | None = None\n    train_on_what: TrainOnWhat = TrainOnWhat.ALL_ASSISTANT_MESSAGES\n    max_length: int = 32768\n    batch_size: int = 128\n\n    # Logging parameters\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\ndef cli_main(cli_config: CLIConfig):\n    # Build full config\n    model_name = cli_config.model_name.replace(\"/\", \"-\")\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    run_name = f\"prompt_distillation-{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.batch_size}batch-{date_and_time}\"\n\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/prompt_distillation/{run_name}\"\n\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    # make sure the data file exists\n    if not Path(cli_config.file_path).exists():\n        raise FileNotFoundError(f\"Data file not found: {cli_config.file_path}\")\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    renderer_name = checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=cli_config.model_name,\n        renderer_name=renderer_name,\n        max_length=cli_config.max_length,\n        batch_size=cli_config.batch_size,\n        train_on_what=cli_config.train_on_what,\n    )\n\n    dataset = FromConversationFileBuilder(\n        common_config=common_config,\n        file_path=cli_config.file_path,\n    )\n\n    config = train.Config(\n        log_path=log_path,\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        dataset_builder=dataset,\n        learning_rate=cli_config.learning_rate,\n        lr_schedule=cli_config.lr_schedule,\n        num_epochs=cli_config.num_epochs,\n     
   base_url=cli_config.base_url,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        lora_rank=cli_config.lora_rank,\n        save_every=cli_config.save_every,\n        eval_every=cli_config.eval_every,\n        max_steps=cli_config.max_steps,\n    )\n    asyncio.run(train.main(config))\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    cli_main(cli_config)\n"
  },
  {
    "path": "tinker_cookbook/recipes/rl_basic.py",
    "content": "import asyncio\nimport sys\n\nimport chz\n\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.recipes.math_rl.math_env import Gsm8kDatasetBuilder\nfrom tinker_cookbook.rl import train\n\n\ndef build_config_blueprint() -> chz.Blueprint[train.Config]:\n    model_name = \"meta-llama/Llama-3.1-8B\"\n    renderer_name = model_info.get_recommended_renderer_name(model_name)\n    builder = Gsm8kDatasetBuilder(\n        batch_size=128,\n        group_size=16,\n        renderer_name=renderer_name,\n        model_name_for_tokenizer=model_name,\n    )\n\n    return chz.Blueprint(train.Config).apply(\n        {\n            \"model_name\": model_name,\n            \"renderer_name\": renderer_name,\n            \"log_path\": \"/tmp/tinker-examples/rl_basic\",\n            \"dataset_builder\": builder,\n            \"learning_rate\": 4e-5,\n            \"max_tokens\": 256,\n            \"eval_every\": 0,\n        }\n    )\n\n\ndef main(config: train.Config):\n    # Avoid clobbering log dir from your previous run:\n    cli_utils.check_log_dir(config.log_path, behavior_if_exists=\"ask\")\n    asyncio.run(train.main(config))\n\n\nif __name__ == \"__main__\":\n    blueprint = build_config_blueprint()\n    blueprint.make_from_argv(sys.argv[1:])\n    main(blueprint.make())\n"
  },
  {
    "path": "tinker_cookbook/recipes/rl_loop.py",
    "content": "\"\"\"\nMinimal RL training loop using GRPO-style reward centering.\n\nVariable naming convention (see CONTRIBUTING.md):\n    _P: Problem dimension (different questions/prompts in a batch)\n    _G: Group dimension (multiple rollouts per problem for variance reduction)\n    _T: Token/Time dimension (sequence positions)\n    _D: Datum dimension (training examples after flattening)\n\nExample: `tokens_G_T` is a list of token sequences, one per group member.\nIn this script, datums_D has size P*G (one datum per rollout).\n\"\"\"\n\nimport logging\nimport time\nfrom concurrent.futures import Future\n\nimport chz\nimport datasets\nimport tinker\nimport torch\nfrom tinker import types\nfrom tinker.types.tensor_data import TensorData\nfrom tqdm import tqdm\n\nfrom tinker_cookbook import checkpoint_utils, model_info, renderers\nfrom tinker_cookbook.recipes.math_rl.math_env import extract_gsm8k_final_answer\nfrom tinker_cookbook.recipes.math_rl.math_grading import extract_boxed, grade_answer\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.utils import ml_log\n\nlogger = logging.getLogger(__name__)\nlogging.getLogger(\"httpx\").setLevel(logging.WARN)\n\n\n@chz.chz\nclass Config:\n    base_url: str | None = None\n    log_path: str = \"/tmp/tinker-examples/rl-loop\"\n    model_name: str = \"meta-llama/Llama-3.1-8B\"\n    batch_size: int = 128\n    group_size: int = 16\n    learning_rate: float = 4e-5\n    lora_rank: int = 32\n    save_every: int = 20  # 0 = disabled\n    max_tokens: int = 256\n    ttl_seconds: int | None = 604800  # 7 days\n\n\ndef get_reward(response: str, answer: str) -> float:\n    try:\n        given_answer = extract_boxed(response)\n        ground_truth = extract_gsm8k_final_answer(answer)\n        return 1.0 if grade_answer(given_answer, ground_truth) else 0.0\n    except ValueError:\n        return 0.0\n\n\ndef main(config: Config):\n    # Setup logging\n    ml_logger = ml_log.setup_logging(\n        log_dir=config.log_path,\n        wandb_project=None,\n        wandb_name=None,\n        config=config,\n        do_configure_logging_module=True,\n    )\n\n    # Get tokenizer and renderer\n    tokenizer = get_tokenizer(config.model_name)\n    renderer_name = model_info.get_recommended_renderer_name(config.model_name)\n    renderer = renderers.get_renderer(renderer_name, tokenizer)\n    logger.info(f\"Using renderer: {renderer_name}\")\n\n    # Load GSM8K dataset\n    logger.info(\"Loading dataset...\")\n    dataset = datasets.load_dataset(\"openai/gsm8k\", \"main\")\n    assert isinstance(dataset, datasets.DatasetDict)\n    train_dataset = dataset[\"train\"]\n\n    question_suffix = \" Provide a numerical answer without units, written inside \\\\boxed{}.\"\n\n    convo_prefix = [\n        {\n            \"role\": \"user\",\n            \"content\": \"How many r's are in strawberry?\" + question_suffix,\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": \"Let's spell the word out and number all the letters: 1) s 2) t 3) r 4) a 5) w 6) b 7) e 8) r 9) r 10) y. We have r's at positions 3, 8, and 9. 
\\\\boxed{3}\",\n        },\n    ]\n\n    n_train_batches = len(train_dataset) // config.batch_size\n\n    # Setup training client\n    service_client = tinker.ServiceClient(base_url=config.base_url)\n\n    resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)\n    if resume_info:\n        training_client = service_client.create_training_client_from_state_with_optimizer(\n            resume_info.state_path\n        )\n        start_batch = resume_info.batch\n        logger.info(f\"Resuming from batch {start_batch}\")\n    else:\n        training_client = service_client.create_lora_training_client(\n            base_model=config.model_name, rank=config.lora_rank\n        )\n        start_batch = 0\n\n    sampling_params = tinker.types.SamplingParams(\n        max_tokens=config.max_tokens,\n        stop=renderer.get_stop_sequences(),\n    )\n    # Optimizer step\n    adam_params = types.AdamParams(\n        learning_rate=config.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8\n    )\n\n    logger.info(f\"Training for {n_train_batches} batches\")\n\n    # Main training loop\n    for batch_idx in range(start_batch, n_train_batches):\n        t_start = time.time()\n        metrics: dict[str, float] = {\n            \"progress/batch\": batch_idx,\n            \"optim/lr\": config.learning_rate,\n            \"progress/done_frac\": (batch_idx + 1) / n_train_batches,\n        }\n\n        # Save checkpoint\n        if config.save_every > 0 and batch_idx % config.save_every == 0 and batch_idx > 0:\n            checkpoint_utils.save_checkpoint(\n                training_client=training_client,\n                name=f\"{batch_idx:06d}\",\n                log_path=config.log_path,\n                kind=\"state\",\n                loop_state={\"batch\": batch_idx},\n                ttl_seconds=config.ttl_seconds,\n            )\n\n        # Get training batch and convert to datums online\n        batch_start = batch_idx * config.batch_size\n        batch_end = min((batch_idx + 1) * config.batch_size, len(train_dataset))\n        batch_rows = train_dataset.select(range(batch_start, batch_end))\n\n        sampling_client = training_client.save_weights_and_get_sampling_client()\n\n        datums_D: list[types.Datum] = []\n        rewards_P: list[float] = []\n        futures_P: list[Future[types.SampleResponse]] = []\n        prompts_P: list[types.ModelInput] = []\n        for question in batch_rows[\"question\"]:\n            convo = [\n                *convo_prefix,\n                {\"role\": \"user\", \"content\": question + question_suffix},\n            ]\n            model_input = renderer.build_generation_prompt(convo)\n\n            # Generate group_size responses in a single call\n            future = sampling_client.sample(\n                prompt=model_input,\n                num_samples=config.group_size,\n                sampling_params=sampling_params,\n            )\n            futures_P.append(future)\n            prompts_P.append(model_input)\n\n        for future, prompt, answer in tqdm(\n            zip(futures_P, prompts_P, batch_rows[\"answer\"]),\n            total=len(futures_P),\n            desc=f\"Sampling batch {batch_idx}\",\n        ):\n            sample_result = future.result()\n            rewards_G: list[float] = []\n            sampled_tokens_G_T: list[list[int]] = []\n            logprobs_G_T: list[list[float]] = []\n            for sequence in sample_result.sequences:\n                sampled_tokens = sequence.tokens\n                sampled_logprobs = 
sequence.logprobs\n                assert sampled_logprobs is not None\n\n                sampled_tokens_G_T.append(sampled_tokens)\n                logprobs_G_T.append(sampled_logprobs)\n\n                parsed_message, _ = renderer.parse_response(sampled_tokens)\n                content = renderers.get_text_content(parsed_message)\n                reward = get_reward(content, answer)\n                rewards_G.append(reward)\n\n            mean_reward = sum(rewards_G) / len(rewards_G)\n            advantages_G = [reward - mean_reward for reward in rewards_G]\n            rewards_P.append(mean_reward)\n\n            # check if all advantages are zero\n            if all(advantage == 0.0 for advantage in advantages_G):\n                # Skip question because all advantages are the same\n                continue\n\n            for sampled_tokens, logprobs, advantage in zip(\n                sampled_tokens_G_T, logprobs_G_T, advantages_G\n            ):\n                ob_len = prompt.length - 1\n                model_input = prompt.append(types.EncodedTextChunk(tokens=sampled_tokens[:-1]))\n                target_tokens = [0] * ob_len + sampled_tokens\n                padded_logprobs = [0.0] * ob_len + logprobs\n                padded_advantages = [0.0] * ob_len + [advantage] * (model_input.length - ob_len)\n                assert (\n                    model_input.length\n                    == len(target_tokens)\n                    == len(padded_logprobs)\n                    == len(padded_advantages)\n                ), (\n                    f\"model_input.length: {model_input.length}, len(target_tokens): {len(target_tokens)}, \"\n                    f\"len(padded_logprobs): {len(padded_logprobs)}, len(padded_advantages): {len(padded_advantages)}\"\n                )\n                datum = types.Datum(\n                    model_input=model_input,\n                    loss_fn_inputs={\n                        \"target_tokens\": TensorData.from_torch(torch.tensor(target_tokens)),\n                        \"logprobs\": TensorData.from_torch(torch.tensor(padded_logprobs)),\n                        \"advantages\": TensorData.from_torch(torch.tensor(padded_advantages)),\n                    },\n                )\n                datums_D.append(datum)\n\n        # Training step\n        if len(datums_D) == 0:\n            logger.warning(\"Batch %d: all advantages zero, skipping training step\", batch_idx)\n        else:\n            fwd_bwd_future = training_client.forward_backward(\n                datums_D, loss_fn=\"importance_sampling\"\n            )\n            optim_step_future = training_client.optim_step(adam_params)\n            _fwd_bwd_result = fwd_bwd_future.result()\n            optim_result = optim_step_future.result()\n\n            if optim_result.metrics:\n                metrics.update(optim_result.metrics)\n\n        # Log metrics\n        metrics[\"time/total\"] = time.time() - t_start\n        metrics[\"reward/total\"] = sum(rewards_P) / len(rewards_P)\n        ml_logger.log_metrics(metrics, step=batch_idx)\n\n        # Save final checkpoint\n    checkpoint_utils.save_checkpoint(\n        training_client=training_client,\n        name=\"final\",\n        log_path=config.log_path,\n        kind=\"both\",\n        loop_state={\"batch\": n_train_batches},\n        ttl_seconds=None,\n    )\n    ml_logger.close()\n    logger.info(\"Training completed\")\n\n\nif __name__ == \"__main__\":\n    chz.nested_entrypoint(main)\n"
  },
  {
    "path": "tinker_cookbook/recipes/rubric/README.md",
    "content": "# Rubric-based Grading for LLMs\n\n- [`data.py`](./data.py) contains the definition for the datapoint class. Each datapoint consists of a conversation prefix and a list of rubric items.\n- [`generate_data.py`](./generate_data.py) generates some example datapoints if you want to run our demo on addition.\n- [`env.py`](./env.py) determines what each rollout will do. It will let the policy read the prefix, generate a response, ask a grader LLM to grade based on a list of rubric items, and finally provide a reward by summing the response of each grader.\n- [`train.py`](./train.py) allows you to train LLMs on any dataset saved in our format (specified in `data.py`). The default script will train on the addition task, whose data is generated by `generate_data.py`.\n- [`prometheus_experimental.py`](./prometheus_experimental.py) contains a script to train the LLMs based on the rubrics from the [`prometheus-eval/Feedback-Collection`](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/viewer/default/train?row=0&views%5B%5D=train) dataset. It is experimental though -- even though the reward goes up, there is no guarantee that the model is actually better. We hope our script serves as a starting point, and more research is needed.\n\n\n## A simple example of using a grader LLM with rubrics\n\nWe show how to use a rubric-based LLM to provide a reward for an addition task. E.g.\n\n```\n**User**: What's 233 + 100?\n**Assistant**: 333\n```\n\nUsually, this could be graded by matching the number to the ground truth 333 without needing an LLM. However, for pedagogical purposes, we will grade the response using a language model with a rubric. That is, we will ask a language model \"Does the assistant answer 333?\"\n\n### Generate an example dataset\n\nTo run this, first generate a dataset:\n\n```\npython -m tinker_cookbook.recipes.rubric.generate_data\n```\n\nThen you will see two `jsonl` files generated, one for training, one for testing. For example, if you look into `tinker_cookbook/example_data/example_rubric_train.jsonl`, each datapoint consists of\n- a convo (the conversation prefix that the policy sees)\n- rubric_items: a list of rubric items that specify what is a good response, how the grader should format the response, and how the grading result should be extracted.\n\n```\n{\n  \"convo\": [\n    {\n      \"role\": \"user\",\n      \"content\": \"What is 4 + 5?\"\n    },\n    {\n      \"role\": \"assistant\",\n      \"content\": \"9\"\n    },\n    {\n      \"role\": \"user\",\n      \"content\": \"What is 122 + 12?\"\n    }\n  ],\n  \"rubric_items\": [\n    {\n      \"rubric_str\": \"Does the chatbot correctly get the answer 134?\",\n      \"extraction_regex\": \"<score>(.*)</score>\",\n      \"grader_output_format_instruction\": \"Please output your score between 0 and 1 wrapped in <score> ... 
</score>\"\n    }\n  ]\n}\n```\n\n### Debugging and Printing What Happens During Rollouts\n\nRun\n```\npython -m tinker_cookbook.recipes.rubric.debug_env\n```\n\nYou can see the message that the policy sees, its response, the grader input, and the grader output.\n\n<img width=\"1168\" height=\"771\" alt=\"Debug output showing the conversation context, policy response, grader prompt, and extracted score\" src=\"https://github.com/user-attachments/assets/9f4e3c89-f21e-49b0-96d6-e2f27bd21b43\" />\n\n\n### An example training run\n\nTo train the LLM to add with a rubric-based LLM, run\n```\npython -m tinker_cookbook.recipes.rubric.train\n```\n\nYou can see the reward quickly goes up.\n\n<img width=\"705\" height=\"279\" alt=\"Training metrics showing reward increasing over training steps for the addition task\" src=\"https://github.com/user-attachments/assets/2f825805-20a7-4cf3-8d06-55d5e9a98098\" />\n\n### A more realistic dataset\n\nWe take the `prometheus-eval/Feedback-Collection` dataset from [Hugging Face](https://huggingface.co/datasets/prometheus-eval/Feedback-Collection/), which contains rubrics to grade general chat responses. Run the following to kick off training:\n\n```\npython -m tinker_cookbook.recipes.rubric.prometheus_experimental\n```\n\nWe can see that the reward climbs up steadily.\n\n<img width=\"1086\" height=\"514\" alt=\"Training metrics showing reward climbing steadily over training steps for the Prometheus dataset\" src=\"https://github.com/user-attachments/assets/8877ea6c-b9ea-46da-b995-046bbd3e7c80\" />\n\nNote that this training recipe is experimental -- to make the performance better we may need to fine-tune the grader LLM as well. We hope our code serves as a starting point for you to improve rubric-based grading for training LLMs!\n"
  },
  {
    "path": "tinker_cookbook/recipes/rubric/data.py",
    "content": "import json\nimport re\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, TypeAlias\n\nimport chz\n\nfrom tinker_cookbook.renderers import (\n    Message,\n    Role,\n)\n\nConversation: TypeAlias = list[Message]\n\n\n@dataclass\nclass Rubric:\n    \"\"\"\n    A rubric should specify 1) what counts as a good response, 2) how the grader language model should output the score, and 3) how to extract the score from the grader's response.\n    \"\"\"\n\n    rubric_str: str\n    extraction_regex: str = r\"<score>(.*)</score>\"\n    grader_output_format_instruction: str = (\n        \"Please output your score between 0 and 1 wrapped in <score> ... </score>\"\n    )\n\n    def _convert_role(self, role: Role) -> str:\n        return \"Human\" if role in (\"user\", \"system\") else \"Chatbot\"\n\n    def _flatten_convo(self, convo: Conversation) -> str:\n        \"\"\"\n        Convert the whole conversation (user's turns + assistant's turns) into a single string. E.g.\n        \\n\\nHuman: ....\n        \\n\\nChatbot: ...\n        \\n\\nHuman: ...\n        \\n\\nChatbot: ...\n        \"\"\"\n        return \"\\n\\n\".join(\n            [f\"{self._convert_role(message['role'])}: {message['content']}\" for message in convo]\n        )\n\n    def get_grader_prompt(self, convo: Conversation) -> Conversation:\n        \"\"\"\n        Create a prompt for the grader to grade the conversation based on the rubric.\n        The prompt separates the context (prior turns) from the completion (last assistant message)\n        so the grader focuses on grading the most recent response.\n        \"\"\"\n        # Separate context from the completion to grade\n        context = convo[:-1]\n        completion = convo[-1]\n\n        lines = [\n            \"I will show you a conversation context, a chatbot completion to grade, and a rubric.\",\n            \"Please grade the chatbot's completion based on the rubric.\",\n            \"\",\n            \"<context>\",\n            self._flatten_convo(context) if context else \"(No prior context)\",\n            \"</context>\",\n            \"\",\n            \"<completion_to_grade>\",\n            f\"Chatbot: {completion['content']}\",\n            \"</completion_to_grade>\",\n            \"\",\n            \"<rubric>\",\n            self.rubric_str,\n            \"</rubric>\",\n            \"\",\n            f\"Please grade the chatbot's completion based on the rubric. 
{self.grader_output_format_instruction}\",\n        ]\n        return [\n            {\n                \"role\": \"user\",\n                \"content\": \"\\n\".join(lines),\n            }\n        ]\n\n    def extract_score(self, response: str) -> float:\n        match = re.search(self.extraction_regex, response, re.DOTALL)\n        if match is not None:\n            try:\n                return float(match.group(1))\n            except ValueError:\n                print(f\"Warning: Failed to extract score from grader response: {response}\")\n                return 0.0\n        else:\n            print(f\"Warning: Failed to extract score from grader response: {response}\")\n            return 0.0\n\n    def to_dict(self) -> dict[str, str]:\n        return {\n            \"rubric_str\": self.rubric_str,\n            \"extraction_regex\": self.extraction_regex,\n            \"grader_output_format_instruction\": self.grader_output_format_instruction,\n        }\n\n    def to_json(self) -> str:\n        return json.dumps(self.to_dict())\n\n    @staticmethod\n    def from_dict(d: dict[str, str]) -> \"Rubric\":\n        return Rubric(\n            rubric_str=d[\"rubric_str\"],\n            extraction_regex=d[\"extraction_regex\"],\n            grader_output_format_instruction=d[\"grader_output_format_instruction\"],\n        )\n\n    @staticmethod\n    def from_json(json_str: str) -> \"Rubric\":\n        return Rubric.from_dict(json.loads(json_str))\n\n\n@dataclass(frozen=True)\nclass RubricBasedDatapoint:\n    \"\"\"\n    A rubric-based datapoint contains a conversation and a rubric.\n    In this task, the policy model sees the conversation, create a response, and then the grader language model grades the response based on the rubric.\n    \"\"\"\n\n    convo: Conversation\n    rubric_items: Sequence[Rubric]\n\n    def to_json(self) -> str:\n        return json.dumps(\n            {\n                \"convo\": self.convo,\n                \"rubric_items\": [rubric.to_dict() for rubric in self.rubric_items],\n            }\n        )\n\n    @staticmethod\n    def from_json(json_str: str) -> \"RubricBasedDatapoint\":\n        d = json.loads(json_str)\n        return RubricBasedDatapoint(\n            convo=d[\"convo\"],\n            rubric_items=[Rubric.from_dict(rubric) for rubric in d[\"rubric_items\"]],\n        )\n\n\n@chz.chz\nclass RubricDatapointListBuilder:\n    def __call__(self) -> Sequence[RubricBasedDatapoint]:\n        \"\"\"Load and return a sequence of rubric-based datapoints.\"\"\"\n        raise NotImplementedError(\"Subclass must implement this method\")\n\n\n@chz.chz\nclass RubricDatapointListBuilderFromJsonl(RubricDatapointListBuilder):\n    jsonl_path: str\n\n    def __call__(self) -> Sequence[RubricBasedDatapoint]:\n        if not Path(self.jsonl_path).exists():\n            raise FileNotFoundError(\n                f\"Data file not found: {self.jsonl_path}\\n\"\n                f\"Please generate the example data first by running:\\n\"\n                f\"  python -m tinker_cookbook.recipes.rubric.generate_data\"\n            )\n        datapoints = []\n        with open(self.jsonl_path) as f:\n            for line in f:\n                datapoints.append(RubricBasedDatapoint.from_json(line))\n        return datapoints\n\n\n@chz.chz\nclass PrometheusDatapointListBuilder(RubricDatapointListBuilder):\n    data_path: str = \"prometheus-eval/Feedback-Collection\"\n\n    def __call__(self) -> Sequence[RubricBasedDatapoint]:\n        from datasets import load_dataset\n\n        
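# Download the prometheus-eval/Feedback-Collection training split from the Hugging Face Hub\n        # and convert each row into a single-rubric RubricBasedDatapoint.\n        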
train_dataset = load_dataset(self.data_path)[\"train\"]\n        return [self.build_rubric_datapoint(item) for item in train_dataset]  # type: ignore\n\n    def build_rubric_datapoint(self, item: dict[str, Any]) -> RubricBasedDatapoint:\n        convo: Conversation = [\n            {\"role\": \"user\", \"content\": item[\"orig_instruction\"]},\n        ]\n\n        rubric_lines = [\n            f\"Your job is to evaluate the following: {item['orig_criteria']}. Your response should be a score between 1 to 5.\",\n            \"Here is the calibration for each score:\",\n        ]\n        for i in range(1, 6):\n            rubric_lines.append(f\"<score>{i}.0</score>: {item[f'orig_score{i}_description']}\")\n        rubric_lines.append(\n            f\"Here is a reference response that achieved a score of 5: {item['orig_reference_answer']}\"\n        )\n        rubric_text = \"\\n\".join(rubric_lines)\n\n        rubric = Rubric(\n            rubric_str=rubric_text,\n            extraction_regex=r\"<score>(.*)</score>\",\n            grader_output_format_instruction=\"Please output your score between 1 and 5 wrapped in <score> ... </score>\",\n        )\n\n        return RubricBasedDatapoint(\n            convo=convo,\n            rubric_items=[rubric],\n        )\n"
  },
  {
    "path": "tinker_cookbook/recipes/rubric/debug_env.py",
    "content": "import asyncio\n\nimport tinker\n\nfrom tinker_cookbook import model_info\nfrom tinker_cookbook.completers import TinkerMessageCompleter, TinkerTokenCompleter\nfrom tinker_cookbook.recipes.rubric.env import Rubric, RubricBasedDatapoint, RubricGradedEnv\nfrom tinker_cookbook.renderers import get_renderer\nfrom tinker_cookbook.rl.rollouts import do_single_rollout\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\ndef get_addition_datapoint() -> RubricBasedDatapoint:\n    datapoint = RubricBasedDatapoint(\n        convo=[\n            {\"role\": \"user\", \"content\": \"What is 4 + 5?\"},\n            {\"role\": \"assistant\", \"content\": \"9\"},\n            {\"role\": \"user\", \"content\": \"What is 125 + 311?\"},\n        ],\n        rubric_items=[\n            Rubric(rubric_str=\"Does the chatbot correctly get the answer 436?\"),\n            Rubric(rubric_str=\"Does the chatbot provide an answer without saying anything else?\"),\n        ],\n    )\n\n    return datapoint\n\n\ndef get_prometheus_datapoint() -> RubricBasedDatapoint:\n    from tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder\n\n    datapoint = PrometheusDatapointListBuilder()()\n    datapoint = datapoint[0]\n    return datapoint\n\n\nasync def main(datapoint: RubricBasedDatapoint):\n    # Configuration parameters\n    policy_name = \"meta-llama/Llama-3.1-8B-Instruct\"\n    grader_name = \"Qwen/Qwen3-30B-A3B-Instruct-2507\"\n    policy_max_tokens = 64\n    grader_max_tokens = 64\n\n    service_client = tinker.ServiceClient()\n    policy = TinkerTokenCompleter(\n        sampling_client=service_client.create_sampling_client(base_model=policy_name),\n        max_tokens=policy_max_tokens,\n    )\n    policy_renderer = get_renderer(\n        model_info.get_recommended_renderer_name(policy_name), get_tokenizer(policy_name)\n    )\n    grader = TinkerMessageCompleter(\n        sampling_client=service_client.create_sampling_client(base_model=grader_name),\n        renderer=get_renderer(\n            model_info.get_recommended_renderer_name(grader_name), get_tokenizer(grader_name)\n        ),\n        max_tokens=grader_max_tokens,\n    )\n\n    env = RubricGradedEnv(\n        renderer=policy_renderer,\n        datapoint=datapoint,\n        grader_llm=grader,\n        debug=True,\n    )\n\n    await do_single_rollout(policy, env)\n\n\nif __name__ == \"__main__\":\n    dataset = \"addition\"\n\n    if dataset == \"addition\":\n        datapoint = get_addition_datapoint()\n        asyncio.run(main(datapoint))\n    elif dataset == \"prometheus\":\n        datapoint = get_prometheus_datapoint()\n        asyncio.run(main(datapoint))\n    else:\n        raise ValueError(f\"Unknown dataset: {dataset}\")\n"
  },
  {
    "path": "tinker_cookbook/recipes/rubric/env.py",
    "content": "import asyncio\nimport json\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\n\nimport chz\nimport tinker\nfrom termcolor import colored\nfrom tinker.types import ModelInput\n\nfrom tinker_cookbook import model_info\nfrom tinker_cookbook.completers import MessageCompleter, StopCondition, TinkerMessageCompleter\nfrom tinker_cookbook.recipes.rubric.data import (\n    Conversation,\n    Rubric,\n    RubricBasedDatapoint,\n    RubricDatapointListBuilder,\n)\nfrom tinker_cookbook.renderers import Renderer, get_renderer\nfrom tinker_cookbook.rl.types import (\n    Action,\n    Env,\n    EnvGroupBuilder,\n    RLDataset,\n    RLDatasetBuilder,\n    StepResult,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.utils import logtree\nfrom tinker_cookbook.utils.logtree_formatters import ConversationFormatter\n\n\nclass RubricGradedEnv(Env):\n    def __init__(\n        self,\n        renderer: Renderer,\n        datapoint: RubricBasedDatapoint,\n        grader_llm: MessageCompleter,\n        debug: bool = False,\n        format_coef: float = 0.1,\n    ):\n        \"\"\"\n        Initialize the RubricGradedEnv. In this environment, the policy model sees the conversation,\n        creates a response, and then the grader language model grades the response based on the rubric.\n        \"\"\"\n        self.renderer = renderer\n        self.datapoint = datapoint\n        self.grader_llm = grader_llm\n        self.debug = debug\n        self.format_coef = format_coef\n\n    @property\n    def rubric_items(self) -> Sequence[Rubric]:\n        return self.datapoint.rubric_items\n\n    @property\n    def convo(self) -> Conversation:\n        return self.datapoint.convo\n\n    @property\n    def stop_condition(self) -> StopCondition:\n        return self.renderer.get_stop_sequences()\n\n    async def initial_observation(self) -> tuple[ModelInput, StopCondition]:\n        return self.renderer.build_generation_prompt(self.convo), self.stop_condition\n\n    async def _grade_with_rubric(self, convo: Conversation, rubric: Rubric) -> tuple[float, str]:\n        # this is the conversation for the grader\n        # effectively it's just one user turn\n        grader_prompt = rubric.get_grader_prompt(convo)\n\n        # obtain the response from the grader and convert it to a score\n        grader_response = await self.grader_llm(grader_prompt)\n        grader_response_content = grader_response[\"content\"]\n        assert isinstance(grader_response_content, str), \"Grader response content must be a string\"\n        score = rubric.extract_score(grader_response_content)\n        if self.debug:\n            print(colored(\"=\" * 80, \"yellow\"))\n            print(colored(\"DEBUG: First Turn of Grader Prompt\", \"yellow\"))\n            print(colored(\"=\" * 80, \"yellow\"))\n            print(colored(grader_prompt[0][\"content\"], \"yellow\") + \"\\n\")\n\n            print(colored(\"=\" * 80, \"magenta\"))\n            print(colored(\"DEBUG: Score\", \"magenta\"))\n            print(colored(\"=\" * 80, \"magenta\"))\n            print(colored(f\"Grader Response: {grader_response_content}\", \"magenta\") + \"\\n\")\n            print(colored(f\"Extracted Score: {score}\", \"magenta\") + \"\\n\")\n        return score, grader_response_content\n\n    async def step(self, action: Action) -> StepResult:\n        with logtree.scope_header(\"Prompt\"):\n            logtree.log_formatter(ConversationFormatter(messages=self.convo))\n\n        # obtain the 
policy action message\n        (policy_action_message, parse_success) = self.renderer.parse_response(action)\n        parse_success_bool = bool(parse_success)\n        format_score = float(parse_success_bool)\n\n        if self.debug:\n            print(\"\\n\" + colored(\"=\" * 80, \"blue\"))\n            print(colored(\"DEBUG: Original Conversation (self.convo)\", \"blue\"))\n            print(colored(\"=\" * 80, \"blue\"))\n            print(colored(json.dumps(self.convo, indent=2), \"blue\") + \"\\n\")\n\n            print(colored(\"=\" * 80, \"green\"))\n            print(colored(\"DEBUG: Policy Action Message\", \"green\"))\n            print(colored(\"=\" * 80, \"green\"))\n            print(colored(json.dumps(policy_action_message, indent=2), \"green\") + \"\\n\")\n            print(colored(f\"Parse Success: {parse_success}\", \"green\") + \"\\n\")\n\n        with logtree.scope_header(\"Policy Response\"):\n            logtree.log_formatter(ConversationFormatter(messages=[policy_action_message]))\n            logtree.log_text(f\"Parse success: {parse_success}\")\n\n        convo = self.convo + [policy_action_message]\n        results = await asyncio.gather(\n            *[self._grade_with_rubric(convo, rubric_item) for rubric_item in self.rubric_items]\n        )\n        scores = [score for score, _ in results]\n        avg_score = sum(scores) / len(scores)\n\n        with logtree.scope_header(\"Rubric Grades\"):\n            rows = []\n            for idx, (rubric_item, (score, grader_response)) in enumerate(\n                zip(self.rubric_items, results, strict=True),\n                start=1,\n            ):\n                rows.append(\n                    {\n                        \"#\": idx,\n                        \"score\": f\"{score:.3f}\",\n                        \"criterion\": rubric_item.rubric_str[:120]\n                        + (\"...\" if len(rubric_item.rubric_str) > 120 else \"\"),\n                    }\n                )\n                with logtree.scope_header(f\"Rubric {idx}: score={score:.3f}\"):\n                    logtree.log_text(f\"Criterion: {rubric_item.rubric_str}\")\n                    logtree.details(grader_response, summary=\"Model output\", pre=True)\n            logtree.table(rows, caption=\"Per-rubric scores\")\n\n        # Apply format penalty similar to ProblemEnv\n        format_penalty = self.format_coef * (format_score - 1)\n        total_reward = format_penalty + avg_score\n\n        with logtree.scope_header(\"Reward Terms\"):\n            logtree.table_from_dict(\n                {\n                    \"rubric_score_mean\": f\"{avg_score:.3f}\",\n                    \"format_parse_success\": parse_success_bool,\n                    \"format_penalty\": f\"{format_penalty:.3f}\",\n                    \"total_reward\": f\"{total_reward:.3f}\",\n                },\n                caption=\"Per-step reward breakdown\",\n            )\n\n        return StepResult(\n            reward=total_reward,\n            episode_done=True,\n            next_observation=self.renderer.build_generation_prompt(convo),\n            next_stop_condition=self.stop_condition,\n            metrics={\n                \"format\": format_score,\n                \"rubric_score\": avg_score,\n            },\n            logs={\n                \"parse_success\": int(parse_success_bool),\n                \"num_rubrics\": len(self.rubric_items),\n            },\n        )\n\n\n@dataclass(frozen=True)\nclass RubricGradedEnvGroupBuilder(EnvGroupBuilder):\n    
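\"\"\"Builds group_size identical RubricGradedEnv instances that share one datapoint and grader LLM.\"\"\"\n\n    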
renderer: Renderer\n    datapoint: RubricBasedDatapoint\n    grader_llm: MessageCompleter\n    group_size: int\n\n    async def make_envs(self) -> Sequence[RubricGradedEnv]:\n        return [\n            RubricGradedEnv(\n                renderer=self.renderer,\n                datapoint=self.datapoint,\n                grader_llm=self.grader_llm,\n            )\n            for _ in range(self.group_size)\n        ]\n\n\n@dataclass(frozen=True)\nclass RubricGradedDataset(RLDataset):\n    renderer: Renderer\n    batch_size: int\n    group_size: int\n    datapoints: Sequence[RubricBasedDatapoint]\n    grader_llm: MessageCompleter\n\n    def get_batch(self, index: int) -> Sequence[RubricGradedEnvGroupBuilder]:\n        batch = [\n            RubricGradedEnvGroupBuilder(\n                renderer=self.renderer,\n                datapoint=self.datapoints[index * self.batch_size + i],\n                grader_llm=self.grader_llm,\n                group_size=self.group_size,\n            )\n            for i in range(self.batch_size)\n        ]\n        return batch\n\n    def __len__(self) -> int:\n        return len(self.datapoints) // self.batch_size\n\n\n@chz.chz\nclass RubricGradedDatasetBuilder(RLDatasetBuilder):\n    renderer_name: str\n    model_name_for_tokenizer: str\n    batch_size: int\n    train_group_size: int\n    test_group_size: int = 1\n\n    train_datapoint_list_builder: RubricDatapointListBuilder\n    test_datapoint_list_builder: RubricDatapointListBuilder | None = None\n\n    base_url: str | None = None\n    grader_llm_name: str = \"Qwen/Qwen3-30B-A3B-Instruct-2507\"\n\n    def _get_grader_llm(self) -> MessageCompleter:\n        tokenizer = get_tokenizer(self.grader_llm_name)\n        renderer_name = model_info.get_recommended_renderer_name(self.grader_llm_name)\n        renderer = get_renderer(name=renderer_name, tokenizer=tokenizer)\n        service_client = tinker.ServiceClient(base_url=self.base_url)\n        sampling_client = service_client.create_sampling_client(base_model=self.grader_llm_name)\n        return TinkerMessageCompleter(\n            sampling_client=sampling_client, renderer=renderer, max_tokens=2048\n        )\n\n    async def __call__(self) -> tuple[RubricGradedDataset, RubricGradedDataset | None]:\n        train_datapoints = self.train_datapoint_list_builder()\n        test_datapoints = None\n        if self.test_datapoint_list_builder is not None:\n            test_datapoints = self.test_datapoint_list_builder()\n\n        renderer = get_renderer(\n            name=self.renderer_name, tokenizer=get_tokenizer(self.model_name_for_tokenizer)\n        )\n\n        assert train_datapoints is not None, \"Train datapoints are required\"\n        train_dataset = RubricGradedDataset(\n            renderer=renderer,\n            batch_size=self.batch_size,\n            group_size=self.train_group_size,\n            datapoints=train_datapoints,\n            grader_llm=self._get_grader_llm(),\n        )\n        if test_datapoints is None:\n            return train_dataset, None\n        else:\n            test_dataset = RubricGradedDataset(\n                renderer=renderer,\n                batch_size=len(test_datapoints),\n                group_size=self.test_group_size,\n                datapoints=test_datapoints,\n                grader_llm=self._get_grader_llm(),\n            )\n            return train_dataset, test_dataset\n"
  },
  {
    "path": "tinker_cookbook/recipes/rubric/generate_data.py",
    "content": "import random\nfrom pathlib import Path\n\nfrom tinker_cookbook.recipes.rubric.data import Rubric, RubricBasedDatapoint\n\n\ndef generate_one(rng: random.Random) -> RubricBasedDatapoint:\n    x, y = rng.randint(0, 1000), rng.randint(0, 1000)\n    return RubricBasedDatapoint(\n        convo=[\n            {\"role\": \"user\", \"content\": \"What is 4 + 5?\"},\n            {\"role\": \"assistant\", \"content\": \"9\"},\n            {\"role\": \"user\", \"content\": f\"What is {x} + {y}?\"},\n        ],\n        rubric_items=[Rubric(rubric_str=f\"Does the chatbot correctly get the answer {x + y}?\")],\n    )\n\n\ndef generate_dataset(\n    num_train: int, num_test: int, seed: int, write_dir: str = \"tinker_cookbook/example_data/\"\n) -> tuple[str, str]:\n    random.seed(seed)\n    rng = random.Random(seed)\n    total_datapoints = num_train + num_test\n    datapoints = [generate_one(rng) for _ in range(total_datapoints)]\n\n    write_path = Path(write_dir)\n    train_datapoints = datapoints[:num_train]\n    train_jsonl_path = str(write_path / \"example_rubric_train.jsonl\")\n    with open(train_jsonl_path, \"w\") as f:\n        for datapoint in train_datapoints:\n            f.write(datapoint.to_json() + \"\\n\")\n    print(f\"Generated {len(train_datapoints)} train datapoints in {train_jsonl_path}\")\n\n    test_datapoints = datapoints[num_train:]\n    test_jsonl_path = str(write_path / \"example_rubric_test.jsonl\")\n    with open(test_jsonl_path, \"w\") as f:\n        for datapoint in test_datapoints:\n            f.write(datapoint.to_json() + \"\\n\")\n    print(f\"Generated {len(test_datapoints)} test datapoints in {test_jsonl_path}\")\n\n    return train_jsonl_path, test_jsonl_path\n\n\nif __name__ == \"__main__\":\n    train_jsonl_path, test_jsonl_path = generate_dataset(num_train=10000, num_test=1000, seed=42)\n    print(f\"Generated train dataset in {train_jsonl_path}\")\n    print(f\"Generated test dataset in {test_jsonl_path}\")\n"
  },
  {
    "path": "tinker_cookbook/recipes/rubric/prometheus_experimental.py",
    "content": "import asyncio\nfrom datetime import datetime\n\nimport chz\nfrom tinker.types import LossFnType\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.recipes.rubric.data import PrometheusDatapointListBuilder\nfrom tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder\nfrom tinker_cookbook.rl.train import AsyncConfig, Config, main\nfrom tinker_cookbook.rl.types import RLDatasetBuilder\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Simple command-line configuration for RL training.\"\"\"\n\n    # Model configuration\n    model_name: str = \"meta-llama/Llama-3.1-8B-Instruct\"\n    lora_rank: int = 32\n    renderer_name: str | None = None\n    load_checkpoint_path: str | None = None\n\n    seed: int = 0  # Random seed for data shuffling\n\n    # Training hyperparameters\n    train_group_size: int = 4\n    test_group_size: int = 1\n    groups_per_batch: int = 100\n    learning_rate: float = 1e-5\n    max_tokens: int = 5\n    temperature: float = 1.0\n    kl_penalty_coef: float = 0.0\n    grader_llm_name: str = \"Qwen/Qwen3-30B-A3B-Instruct-2507\"\n    # Number of optimizer steps per training iteration.\n    # Useful for very large batch sizes.\n    num_substeps: int = 1\n\n    # Logging configuration\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    compute_post_kl: bool = False\n\n    # Evals\n    eval_every: int = 20\n\n    # Checkpointing\n    save_every: int = 20\n\n    # Service configuration\n    base_url: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps_off_policy: int | None = None\n    loss_fn: LossFnType = \"importance_sampling\"\n\n    max_steps: int | None = None\n\n\ndef get_dataset_builder(\n    batch_size: int,\n    policy_model_name: str,\n    renderer_name: str,\n    grader_llm_name: str,\n    train_group_size: int,\n    test_group_size: int = 1,\n) -> RLDatasetBuilder:\n    return RubricGradedDatasetBuilder(\n        batch_size=batch_size,\n        model_name_for_tokenizer=policy_model_name,\n        renderer_name=renderer_name,\n        grader_llm_name=grader_llm_name,\n        train_datapoint_list_builder=PrometheusDatapointListBuilder(),\n        test_datapoint_list_builder=None,\n        train_group_size=train_group_size,\n        test_group_size=test_group_size,\n    )\n\n\nasync def cli_main(cli_config: CLIConfig):\n    \"\"\"Convert CLI config to full config and run training.\"\"\"\n\n    # Get tokenizer for stop sequences\n    renderer_name = await checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n    model_name = cli_config.model_name.replace(\"/\", \"-\")\n    run_name = f\"prometheus_experimental-{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.train_group_size}group_size-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n    # create log path if it doesn't exist\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/rubric/{run_name}\"\n\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    # Create 
full config\n    config = Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_builder=get_dataset_builder(\n            batch_size=cli_config.groups_per_batch,\n            policy_model_name=cli_config.model_name,\n            renderer_name=renderer_name,\n            grader_llm_name=cli_config.grader_llm_name,\n            train_group_size=cli_config.train_group_size,\n            test_group_size=cli_config.test_group_size,\n        ),\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        temperature=cli_config.temperature,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        log_path=log_path,\n        base_url=cli_config.base_url,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        compute_post_kl=cli_config.compute_post_kl,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        num_substeps=cli_config.num_substeps,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        async_config=AsyncConfig(\n            max_steps_off_policy=cli_config.max_steps_off_policy,\n            groups_per_batch=cli_config.groups_per_batch,\n        )\n        if cli_config.max_steps_off_policy is not None\n        else None,\n        loss_fn=cli_config.loss_fn,\n        max_steps=cli_config.max_steps,\n    )\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    # Run training\n    await main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/rubric/train.py",
    "content": "import asyncio\nfrom datetime import datetime\n\nimport chz\nfrom tinker.types import LossFnType\n\nfrom tinker_cookbook import checkpoint_utils, cli_utils\nfrom tinker_cookbook.recipes.rubric.data import RubricDatapointListBuilderFromJsonl\nfrom tinker_cookbook.recipes.rubric.env import RubricGradedDatasetBuilder\nfrom tinker_cookbook.rl.train import AsyncConfig, Config, main\nfrom tinker_cookbook.rl.types import RLDatasetBuilder\n\n\n@chz.chz\nclass CLIConfig:\n    \"\"\"Simple command-line configuration for RL training.\"\"\"\n\n    # Model configuration\n    model_name: str = \"meta-llama/Llama-3.1-8B-Instruct\"\n    lora_rank: int = 32\n    renderer_name: str | None = None\n    load_checkpoint_path: str | None = None\n\n    seed: int = 0  # Random seed for data shuffling\n\n    # Training hyperparameters\n    train_group_size: int = 4\n    test_group_size: int = 1\n    groups_per_batch: int = 100\n    learning_rate: float = 1e-5\n    max_tokens: int = 5\n    temperature: float = 1.0\n    kl_penalty_coef: float = 0.0\n    grader_llm_name: str = \"Qwen/Qwen3-30B-A3B-Instruct-2507\"\n    train_jsonl_path: str = \"tinker_cookbook/example_data/example_rubric_train.jsonl\"\n    test_jsonl_path: str = \"tinker_cookbook/example_data/example_rubric_test.jsonl\"\n\n    # Number of optimizer steps per training iteration.\n    # Useful for very large batch sizes.\n    num_substeps: int = 1\n\n    # Logging configuration\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    compute_post_kl: bool = False\n\n    # Evals\n    eval_every: int = 20\n\n    # Checkpointing\n    save_every: int = 20\n\n    # Service configuration\n    base_url: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps_off_policy: int | None = None\n    loss_fn: LossFnType = \"importance_sampling\"\n\n    max_steps: int | None = None\n\n\ndef get_dataset_builder(\n    batch_size: int,\n    policy_model_name: str,\n    renderer_name: str,\n    grader_llm_name: str,\n    train_group_size: int,\n    train_jsonl_path: str,\n    test_jsonl_path: str | None = None,\n    test_group_size: int = 1,\n) -> RLDatasetBuilder:\n    return RubricGradedDatasetBuilder(\n        batch_size=batch_size,\n        model_name_for_tokenizer=policy_model_name,\n        renderer_name=renderer_name,\n        grader_llm_name=grader_llm_name,\n        train_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(\n            jsonl_path=train_jsonl_path\n        ),\n        test_datapoint_list_builder=RubricDatapointListBuilderFromJsonl(jsonl_path=test_jsonl_path)\n        if test_jsonl_path is not None\n        else None,\n        train_group_size=train_group_size,\n        test_group_size=test_group_size,\n    )\n\n\nasync def cli_main(cli_config: CLIConfig):\n    \"\"\"Convert CLI config to full config and run training.\"\"\"\n\n    # Get tokenizer for stop sequences\n    renderer_name = await checkpoint_utils.resolve_renderer_name_from_checkpoint_or_default_async(\n        model_name=cli_config.model_name,\n        explicit_renderer_name=cli_config.renderer_name,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        base_url=cli_config.base_url,\n    )\n    model_name = cli_config.model_name.replace(\"/\", \"-\")\n    run_name = 
f\"{model_name}-{cli_config.lora_rank}rank-{cli_config.learning_rate}lr-{cli_config.train_group_size}group_size-{cli_config.groups_per_batch}batch-{cli_config.loss_fn}-seed{cli_config.seed}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}\"\n    # create log path if it doesn't exist\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        log_path = f\"/tmp/tinker-examples/rubric/{run_name}\"\n\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    # Create full config\n    config = Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_builder=get_dataset_builder(\n            batch_size=cli_config.groups_per_batch,\n            policy_model_name=cli_config.model_name,\n            renderer_name=renderer_name,\n            grader_llm_name=cli_config.grader_llm_name,\n            train_group_size=cli_config.train_group_size,\n            train_jsonl_path=cli_config.train_jsonl_path,\n            test_jsonl_path=cli_config.test_jsonl_path,\n            test_group_size=cli_config.test_group_size,\n        ),\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        lora_rank=cli_config.lora_rank,\n        max_tokens=cli_config.max_tokens,\n        temperature=cli_config.temperature,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        log_path=log_path,\n        base_url=cli_config.base_url,\n        load_checkpoint_path=cli_config.load_checkpoint_path,\n        compute_post_kl=cli_config.compute_post_kl,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        num_substeps=cli_config.num_substeps,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        async_config=AsyncConfig(\n            max_steps_off_policy=cli_config.max_steps_off_policy,\n            groups_per_batch=cli_config.groups_per_batch,\n        )\n        if cli_config.max_steps_off_policy is not None\n        else None,\n        loss_fn=cli_config.loss_fn,\n        max_steps=cli_config.max_steps,\n    )\n\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    # Run training\n    await main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/search_tool/README.md",
    "content": "# Replicating Search-R1 with Tinker\n\n[Search-R1](https://arxiv.org/pdf/2503.09516) is a recent paper that showcases tool-use RL for multi-hop QA on Wikipedia.\nIt provides a clean setup for testing tool-use RL and also released their training and evaluation data.\nIn this demo, we demonstrate similar experiments using `Qwen3-4B-Instruct-2507`, and we include our replication results using `Qwen/Qwen2.5-7B-Instruct` at the end.\n\n## Running This Demo\n\n### Installation and Setup\n\nThis demo is built with Chroma DB and the Gemini API. You can install the additional dependencies by\n\n```bash\nuv pip install -e .[vector-search]\n```\n\nBy default, we use google vertex ai for the embedding service, and you need to set `$GOOGLE_GENAI_USE_VERTEXAI`, `$GCP_VERTEXAI_PROJECT_NUMBER`, `$GCP_VERTEXAI_REGION`. Or, tweak `./embedding.py` to authenticate differently.\n\nCurrently, the tool use RL run relies on a separate Chroma vector search service. You can set it up with the following step:\n\n1. You can download a pre-computed wiki18 index: https://huggingface.co/datasets/tianyi-thinks/2018-wiki-index/blob/main/chroma_db.tar.xz\n2. Launch the Chroma service on localhost. Example command: `chroma run --host localhost --path <decompressed_path>/chroma_db --port 8000`\n\nIf you launch the chroma service locally, you generally need 160+ GB RAM to load the vector index in memory for good performance.\n\n### Example command\n\nThis default command trains a `Qwen3-4B-Instruct-2507` with reasonable hyperparameters.\n\n```bash\npython -m tinker_cookbook.recipes.search_tool.train\n```\n\nWith the default hyperparameters, you can expect performance like:\n| | Natural Questions | Trivia QA | HotpotQA | 2WikiMultihopQA |\n|---|---|---|---|---|\n| Qwen3-4B-Instruct-2507 | 51.8 | 70.2 | 52.0 | 47.7 |\n\nA successful run generally learns multi-turn search within 10-25 steps, which can be monitored by checking if `env/all/turns_per_episode` has increased over 2 turns.\n\n**Note:** The `max_trajectory_tokens` parameter (default: 32,768) limits the total context length for multi-turn interactions. If your searches require longer contexts, you can adjust it with `max_trajectory_tokens=<value>`.\n\nTo speed up training, you may consider turning on `--stream_minibatch`. In principle, this system improvement should have minimal effect on training.\n\n### Extensions: How to Include Other Tools?\n\n1. The tool call rendering / parsing logic is in [tinker_cookbook/renderers/](../../renderers/). Tool calling is supported on multiple renderers (Qwen, GPT-OSS, DeepSeek, Kimi). Currently, the system prompt necessary for enabling tool calling is included in `./search_env.py` (`SEARCH_TASK_INSTRUCTIONS`) and is written specifically for Qwen. Changing the tool calling parsing format requires updating the system prompt accordingly.\n2. Extend `./embedding.py` to replace the Gemini embedding.\n3. 
Extend `./tools.py` to add new tools using the `@tool` decorator - see `ChromaTool.search()` as an example.\n\n### Replication Results\n\nWe conducted experiments on a `Qwen/Qwen2.5-7B-Instruct` model and compared with the results reported in the original paper.\nNote this model is not available on Tinker and we chose it specifically to compare with the original paper.\nThe results can be seen here,\n\n|                | Natural Questions | Trivia QA | HotpotQA | 2WikiMultihopQA |\n| -------------- | ----------------- | --------- | -------- | --------------- |\n| original paper | 42.9              | 62.3      | 38.6     | 34.6            |\n| tinker         | **51.6**          | **67.3**  | **49.7** | **42.8**        |\n\nThe key differences between our experiment and the original paper include:\n\n1. We used the default importance-weighting REINFORCE loss implemented in Tinker\n2. We used the default synchronous rollout logic in the Tinker Cookbook\n3. We used Gemini embedding and Chroma DB, motivated by their ease of setup for a public demo. In exploratory experiments, the Gemini embedding does not improve RL performance over the E5 embedding model used in the original paper.\n\n[1] Jin, B., Zeng, H., Yue, Z., Yoon, J., Arık, S. O., Wang, D., Zamani, H., & Han, J. (2025). Search-R1: Training LLMs to reason and leverage search engines with reinforcement learning. arXiv preprint arXiv:2503.09516.\n"
  },
  {
    "path": "tinker_cookbook/recipes/search_tool/chroma_pickle_test.py",
    "content": "\"\"\"Tests for picklability of ChromaTool.\"\"\"\n\nimport pytest\n\ntry:\n    import chromadb as _chromadb  # noqa: F401\n\n    _has_chromadb = True\nexcept ImportError:\n    _has_chromadb = False\n\n\n@pytest.mark.skipif(not _has_chromadb, reason=\"chromadb not installed\")\nclass TestChromatoolPickle:\n    def test_pickle_excludes_clients(self) -> None:\n        \"\"\"ChromaTool excludes async clients from pickle state and preserves connection params.\"\"\"\n        from unittest.mock import MagicMock\n\n        from tinker_cookbook.recipes.search_tool.tools import ChromaTool, RetrievalConfig\n\n        tool = ChromaTool(\n            chroma_client=MagicMock(),\n            gemini_client=MagicMock(),\n            collection_name=\"wiki_chunks\",\n            retrieval_config=RetrievalConfig(),\n            max_retries=5,\n            initial_retry_delay=2,\n            chroma_host=\"localhost\",\n            chroma_port=8000,\n        )\n        state = tool.__getstate__()\n        assert state[\"_chroma_client\"] is None\n        assert state[\"_gemini_client\"] is None\n        assert state[\"_chroma_host\"] == \"localhost\"\n        assert state[\"_chroma_port\"] == 8000\n        assert state[\"_collection_name\"] == \"wiki_chunks\"\n        assert state[\"_max_retries\"] == 5\n"
  },
  {
    "path": "tinker_cookbook/recipes/search_tool/embedding.py",
    "content": "\"\"\"\nShared utilities for Gemini embedding generation with retry logic\n\"\"\"\n\nimport asyncio\nfrom logging import getLogger\nfrom os import environ\nfrom typing import Any\n\nimport google.genai as genai\nfrom google.genai import types\n\nlogger = getLogger(__name__)\n\n# Retry configuration - using the more conservative setting from query_wiki.py\nMAX_RETRIES = 10\nRETRY_DELAY = 1.0\n\n\ndef get_gemini_client(\n    *,\n    vertexai: bool | None = None,\n    project: str | None = None,\n    location: str | None = None,\n    http_options: types.HttpOptions | None = None,\n    **kwargs: Any,\n) -> genai.Client:\n    import google.genai as genai\n    from google.genai.types import HttpOptions\n\n    project = project or environ.get(\"GCP_VERTEXAI_PROJECT_NUMBER\")\n    if project is None:\n        raise ValueError(\"$GCP_VERTEXAI_PROJECT_NUMBER is not set\")\n\n    location = location or environ.get(\"GCP_VERTEXAI_REGION\")\n    if location is None:\n        raise ValueError(\"$GCP_VERTEXAI_REGION is not set\")\n\n    return genai.Client(\n        vertexai=(\n            environ.get(\"GOOGLE_GENAI_USE_VERTEXAI\", \"True\").lower().strip().startswith(\"t\")\n            if vertexai is None\n            else vertexai\n        ),\n        project=project,\n        location=location,\n        http_options=http_options or HttpOptions(api_version=\"v1\", timeout=10 * 1000),\n        **kwargs,\n    )\n\n\nasync def get_gemini_embedding(\n    client: genai.Client,\n    texts: list[str],\n    model: str = \"gemini-embedding-001\",\n    embedding_dim: int = 768,\n    task_type: str = \"RETRIEVAL_QUERY\",\n    max_retries: int = MAX_RETRIES,\n    retry_delay: float = RETRY_DELAY,\n) -> list[list[float]]:\n    \"\"\"\n    Get embeddings from Gemini API with exponential backoff retry logic.\n\n    Always takes a list of strings and returns a list of embeddings.\n\n    Args:\n        texts: List of texts to embed\n        model: Gemini embedding model name (default: \"gemini-embedding-001\")\n        embedding_dim: Desired embedding dimension (default: 768)\n        task_type: Embedding task type (default: \"RETRIEVAL_QUERY\")\n        max_retries: Maximum number of retries (default: 10)\n        retry_delay: Delay between retries (default: 1.0)\n\n    Returns:\n        List of embeddings (list of list of floats) -- guaranteed to be the same length as the input texts\n\n    Raises:\n        Exception: If embedding generation fails after all retries\n    \"\"\"\n    # Validate input\n    if not texts:\n        raise ValueError(\"No texts provided for embedding generation\")\n\n    for i, text in enumerate(texts):\n        if not isinstance(text, str):\n            raise ValueError(f\"Text at index {i} is not a string: {type(text)} = {text}\")\n        if not text.strip():\n            raise ValueError(f\"Text at index {i} is empty or whitespace only\")\n\n    # Retry logic with exponential backoff\n    for attempt in range(max_retries):\n        try:\n            async with asyncio.timeout(10):\n                response = await client.aio.models.embed_content(\n                    model=model,\n                    contents=texts,  # pyright: ignore - Pass the list of texts directly works\n                    config=types.EmbedContentConfig(\n                        task_type=task_type, output_dimensionality=embedding_dim\n                    ),\n                )\n\n            if response.embeddings is None or len(response.embeddings) == 0:\n                raise ValueError(\"No 
embeddings returned from Gemini API\")\n\n            if len(response.embeddings) != len(texts):\n                raise ValueError(\n                    f\"Mismatch: expected {len(texts)} embeddings, got {len(response.embeddings)}\"\n                )\n\n            # Extract embedding values\n            embeddings: list[list[float]] = []\n            for i, embedding in enumerate(response.embeddings):\n                if embedding.values is None:\n                    raise ValueError(f\"No embedding values returned for text {i}\")\n                embeddings.append(embedding.values)\n\n            return embeddings\n\n        except Exception as e:\n            if attempt < max_retries - 1:\n                wait_time = retry_delay * (1.5**attempt)  # Exponential backoff\n                logger.error(\n                    f\"Attempt {attempt + 1}/{max_retries} failed for embedding ({len(texts)} texts): {e!r}. Retrying in {wait_time:.1f}s...\"\n                )\n                await asyncio.sleep(wait_time)\n            else:\n                logger.error(\n                    f\"All {max_retries} attempts failed for embedding ({len(texts)} texts): {e!r}\"\n                )\n                raise\n\n    # This should never be reached due to the raise above, but satisfies type checker\n    raise RuntimeError(\"Unexpected error in retry logic\")\n"
  },
  {
    "path": "tinker_cookbook/recipes/search_tool/offline_eval.py",
    "content": "import asyncio\nimport random\nfrom collections import defaultdict\nfrom typing import Literal, TypedDict\n\nimport chz\nimport tinker\n\nfrom tinker_cookbook import checkpoint_utils, model_info, tokenizer_utils\nfrom tinker_cookbook.completers import TinkerTokenCompleter\nfrom tinker_cookbook.recipes.search_tool.search_env import (\n    SEARCH_TASK_INSTRUCTIONS,\n    SearchR1Datum,\n    download_search_r1_dataset,\n)\nfrom tinker_cookbook.recipes.search_tool.tools import (\n    ChromaTool,\n    EmbeddingConfig,\n    RetrievalConfig,\n    TextAnswerReward,\n)\nfrom tinker_cookbook.renderers import Renderer, get_renderer\nfrom tinker_cookbook.rl.rollouts import do_single_rollout\nfrom tinker_cookbook.tool_use import build_agent_tool_env\n\nROLLOUT_CONCURRENCY = 1024\nrollout_semaphore = asyncio.Semaphore(ROLLOUT_CONCURRENCY)\n\n\n@chz.chz\nclass CLIConfig:\n    # Evaluation parameters\n    max_eval_samples: int = chz.field(\n        default=100, doc=\"Maximum number of samples to evaluate per data source\"\n    )\n    seed: int = chz.field(default=42, doc=\"Random seed for sampling\")\n    split: Literal[\"train\", \"test\"] = chz.field(default=\"test\", doc=\"Dataset split to use\")\n\n    # Model parameters\n    base_model: str = chz.field(default=\"Qwen/Qwen3-4B-Instruct-2507\", doc=\"Base model to use\")\n    tinker_checkpoint_url: str = chz.field(doc=\"Tinker checkpoint URL (required)\")\n    max_tokens: int = chz.field(default=1024, doc=\"Maximum number of tokens to generate\")\n\n\nclass EvaluationResult(TypedDict):\n    question: str\n    correct_score: float\n    trajectory: object\n\n\ndef split_data_by_source(data: list[SearchR1Datum]) -> dict[str, list[SearchR1Datum]]:\n    \"\"\"Split data by data source.\"\"\"\n    data_by_source = defaultdict(list)\n    for item in data:\n        data_by_source[item[\"data_source\"]].append(item)\n    return dict(data_by_source)\n\n\ndef sample_k_from_each_source(\n    data_by_source: dict[str, list[SearchR1Datum]], k: int, seed: int = 42\n) -> dict[str, list[SearchR1Datum]]:\n    \"\"\"Sample K items from each data source.\"\"\"\n    random.seed(seed)\n    sampled_data = {}\n    total_samples = 0\n\n    for source, items in data_by_source.items():\n        if len(items) <= k:\n            sampled_data[source] = items\n        else:\n            sampled_data[source] = random.sample(items, k)\n        total_samples += len(sampled_data[source])\n        print(f\"{source}: {len(items)} -> {len(sampled_data[source])} samples\")\n\n    print(f\"Total samples: {total_samples}\")\n    return sampled_data\n\n\nasync def evaluate_single_item(\n    item: SearchR1Datum,\n    chroma_tool: ChromaTool,\n    policy: TinkerTokenCompleter,\n    renderer: Renderer,\n) -> EvaluationResult:\n    tool_schemas = [chroma_tool.search.to_spec()]\n    initial_messages = renderer.create_conversation_prefix_with_tools(\n        tools=tool_schemas,\n        system_prompt=SEARCH_TASK_INSTRUCTIONS,\n    ) + [{\"role\": \"user\", \"content\": item[\"question\"]}]\n\n    env = build_agent_tool_env(\n        renderer=renderer,\n        tools=[chroma_tool.search],\n        initial_messages=initial_messages,\n        reward_fn=TextAnswerReward(gold_answers=item[\"answer\"], format_coef=0.1),\n        max_turns=5,\n    )\n    async with rollout_semaphore:\n        trajectory = await do_single_rollout(policy, env)\n\n    # Extract correct metric from the last transition\n    correct_score = 0.0\n    if trajectory.transitions:\n        correct_score = 
trajectory.transitions[-1].metrics.get(\"correct\", 0.0)\n\n    return {\"question\": item[\"question\"], \"correct_score\": correct_score, \"trajectory\": trajectory}\n\n\nasync def evaluate_one_dataset(data: list[SearchR1Datum], config: CLIConfig):\n    # Load model\n    service_client = tinker.ServiceClient()\n    sampling_client = service_client.create_sampling_client(model_path=config.tinker_checkpoint_url)\n    policy = TinkerTokenCompleter(sampling_client, max_tokens=config.max_tokens)\n\n    tokenizer = tokenizer_utils.get_tokenizer(config.base_model)\n    renderer_name = await checkpoint_utils.get_renderer_name_from_checkpoint_async(\n        service_client, config.tinker_checkpoint_url\n    )\n    if renderer_name is None:\n        renderer_name = model_info.get_recommended_renderer_name(config.base_model)\n    print(f\"Using renderer: {renderer_name}\")\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    chroma_tool = await ChromaTool.build(\n        chroma_host=\"localhost\",\n        chroma_port=8000,\n        chroma_collection_name=\"wiki_embeddings\",\n        retrieval_config=RetrievalConfig(\n            n_results=3,\n            embedding_config=EmbeddingConfig(\n                model_name=\"gemini-embedding-001\",\n                embedding_dim=768,\n            ),\n        ),\n    )\n\n    # Run evaluations in parallel using asyncio.gather\n    tasks = [evaluate_single_item(item, chroma_tool, policy, renderer) for item in data]\n\n    print(f\"Evaluating {len(tasks)} items\")\n    results = await asyncio.gather(*tasks)\n\n    # Aggregate results\n    correct_scores = [result[\"correct_score\"] for result in results]\n\n    if correct_scores:\n        total_correct = sum(correct_scores)\n        accuracy = total_correct / len(correct_scores)\n        return {\n            \"total_samples\": len(correct_scores),\n            \"total_correct\": total_correct,\n            \"accuracy\": accuracy,\n        }\n\n    return {\"total_samples\": 0, \"total_correct\": 0, \"accuracy\": 0.0}\n\n\nasync def cli_main(config: CLIConfig):\n    # Download the data\n    print(f\"Downloading {config.split} split...\")\n    data = download_search_r1_dataset(config.split)\n    print(f\"Total data points: {len(data)}\")\n\n    # Split by data source\n    data_by_source = split_data_by_source(data)\n    print(f\"\\nData sources found: {list(data_by_source.keys())}\")\n    print(\"Original distribution:\")\n    for source, items in data_by_source.items():\n        print(f\"  {source}: {len(items)}\")\n\n    # Sample K from each source\n    print(f\"\\nSampling up to {config.max_eval_samples} samples from each source...\")\n    sampled_data_by_source = sample_k_from_each_source(\n        data_by_source, config.max_eval_samples, config.seed\n    )\n\n    # Collect results from all datasets\n    dataset_results = {}\n    for source, data in sampled_data_by_source.items():\n        print(f\"Evaluating {source}...\")\n        result = await evaluate_one_dataset(data, config)\n        dataset_results[source] = result\n\n    # Print results table\n    print(\"\\n\" + \"=\" * 80)\n    print(\"EVALUATION RESULTS\")\n    print(\"=\" * 80)\n    print(f\"{'Dataset':<15} {'Accuracy':<10} {'Correct':<10} {'Total':<10}\")\n    print(\"-\" * 80)\n\n    total_all_correct = 0\n    total_all_samples = 0\n\n    for dataset, result in dataset_results.items():\n        accuracy = result[\"accuracy\"]\n        correct = result[\"total_correct\"]\n        total = result[\"total_samples\"]\n        
total_all_correct += correct\n        total_all_samples += total\n        print(f\"{dataset:<15} {accuracy:<10.3f} {correct:<10.0f} {total:<10}\")\n\n    if total_all_samples > 0:\n        overall_accuracy = total_all_correct / total_all_samples\n        print(\"-\" * 80)\n        print(\n            f\"{'OVERALL':<15} {overall_accuracy:<10.3f} {total_all_correct:<10.0f} {total_all_samples:<10}\"\n        )\n    print(\"=\" * 80)\n\n\nif __name__ == \"__main__\":\n    config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/search_tool/search_env.py",
    "content": "from __future__ import annotations\n\nimport os\nimport random\nfrom collections.abc import Sequence\nfrom pathlib import Path\nfrom typing import Literal, TypedDict, cast\n\nimport chz\nimport pandas as pd\nfrom huggingface_hub import hf_hub_download\n\nfrom tinker_cookbook import model_info, tokenizer_utils\nfrom tinker_cookbook.recipes.search_tool.tools import (\n    ChromaTool,\n    RetrievalConfig,\n    TextAnswerReward,\n)\nfrom tinker_cookbook.renderers import get_renderer\nfrom tinker_cookbook.renderers.base import Message, Renderer\nfrom tinker_cookbook.rl.types import Env, EnvGroupBuilder, RLDataset, RLDatasetBuilder\nfrom tinker_cookbook.tool_use import build_agent_tool_env\n\nSEARCH_TASK_INSTRUCTIONS = \"\"\"You are an expert assistant who solves tasks using a Wikipedia search tool.\n\nHere are instructions for how to solve a problem:\n1. Think step by step before calling the tool and after you receive the result of the tool call. Decide what queries to call the tool with.\n2. Call the tool with the queries you have decided on.\n3. Think step by step again after you receive the result of the tool call. If you have the information you need, you can stop here.\n4. Otherwise, come up with new queries that combine information from the previous results.\n5. Include your final answer after the \"Answer:\" prefix. The answer should be between one to five words.\n\nHere is an example of solving a real question:\n\"Between 2020 and 2025, which year did New York City see the most population growth and how did San Francisco population change in that year?\"\n\n1. Think step by step: In order to answer this question, I need to know the population of New York City and San Francisco between 2020 and 2025. I will search for the population of New York City in each year\n2. Calling search tool: <tool_call>{\"name\": \"search\", \"arguments\": {\"query_list\": [\"Population New York city between 2020 and 2025\"]}}</tool_call> (Output omitted for brevity)\n3. Think step by step again: I have the population of New York City in each year, and I see that the population of New York City grew the most in 2024. I need to know the population of San Francisco in 2024. I will search for the population of San Francisco in each year.\n<tool_call>{\"name\": \"search\", \"arguments\": {\"query_list\": [\"Population San Francisco between 2023 and 2024\"]}}</tool_call> (Output omitted for brevity)\n4. 
Answer: The population of New York City grew the most in 2024, and the population of San Francisco changed by XXXX in 2024.\n\"\"\"\n\n\nclass SearchR1Datum(TypedDict):\n    question: str\n    answer: list[str]\n    data_source: str\n\n\ndef process_single_row(row_series: pd.Series) -> SearchR1Datum:\n    \"\"\"\n    Process a single row of data for SearchR1-like format.\n\n    Args:\n        row_series: DataFrame row containing the original data\n\n    Returns:\n        SearchR1Datum: Processed row data in the required format\n    \"\"\"\n    import numpy as np\n\n    row = row_series.to_dict()\n    question: str = row.get(\"question\", \"\")\n\n    # Extract ground truth from reward_model or fallback to golden_answers\n    reward_model_data = row.get(\"reward_model\")\n    if isinstance(reward_model_data, dict) and \"ground_truth\" in reward_model_data:\n        ground_truth = reward_model_data.get(\"ground_truth\")\n    else:\n        ground_truth = row.get(\"golden_answers\", [])\n\n    # NOTE(tianyi)\n    # I hate datasets with mixed types but it is what it is.\n    if isinstance(ground_truth, dict):\n        ground_truth = ground_truth[\"target\"]\n    if isinstance(ground_truth, np.ndarray):\n        ground_truth = ground_truth.tolist()\n\n    assert isinstance(ground_truth, list)\n    for item in ground_truth:\n        assert isinstance(item, str)\n    ground_truth = cast(list[str], ground_truth)\n    return {\n        \"question\": question,\n        \"answer\": ground_truth,\n        \"data_source\": row[\"data_source\"],\n    }\n\n\ndef download_search_r1_dataset(split: Literal[\"train\", \"test\"]) -> list[SearchR1Datum]:\n    hf_repo_id: str = \"PeterJinGo/nq_hotpotqa_train\"\n    parquet_filename: str = f\"{split}.parquet\"\n    # TODO(tianyi): make download dir configurable for release\n    user = os.getenv(\"USER\", \"unknown\")\n    assert user is not None\n    tmp_download_dir = Path(\"/tmp\") / user / \"data\" / hf_repo_id / split\n    tmp_download_dir.mkdir(parents=True, exist_ok=True)\n\n    local_parquet_filepath = hf_hub_download(\n        repo_id=hf_repo_id,\n        filename=parquet_filename,\n        repo_type=\"dataset\",\n        local_dir=tmp_download_dir,\n    )\n\n    df_raw = pd.read_parquet(local_parquet_filepath)\n\n    return df_raw.apply(process_single_row, axis=1).tolist()\n\n\ndef _initial_messages(\n    datum: SearchR1Datum,\n    renderer: Renderer,\n    chroma_tool: ChromaTool,\n) -> list[Message]:\n    \"\"\"Build initial messages with tool schemas and task question.\"\"\"\n    tool_schemas = [chroma_tool.search.to_spec()]\n    prefix = renderer.create_conversation_prefix_with_tools(\n        tools=tool_schemas,\n        system_prompt=SEARCH_TASK_INSTRUCTIONS,\n    )\n    return prefix + [{\"role\": \"user\", \"content\": datum[\"question\"]}]\n\n\nclass SearchEnvGroupBuilder(EnvGroupBuilder):\n    \"\"\"EnvGroupBuilder that creates search environments with a shared ChromaTool.\"\"\"\n\n    def __init__(\n        self,\n        datum: SearchR1Datum,\n        model_name: str,\n        renderer_name: str | None,\n        max_turns: int,\n        group_size: int,\n        chroma_tool: ChromaTool,\n        format_coef: float = 0.1,\n        max_trajectory_tokens: int = 32 * 1024,\n    ):\n        self.datum = datum\n        self.model_name = model_name\n        self.renderer_name = renderer_name\n        self.max_turns = max_turns\n        self.group_size = group_size\n        self.chroma_tool = chroma_tool\n        self.format_coef = format_coef\n      
  self.max_trajectory_tokens = max_trajectory_tokens\n\n    async def make_envs(self) -> Sequence[Env]:\n        tokenizer = tokenizer_utils.get_tokenizer(self.model_name)\n        renderer_name = self.renderer_name or model_info.get_recommended_renderer_name(\n            self.model_name\n        )\n        renderer = get_renderer(renderer_name, tokenizer)\n\n        # Tool, initial_messages, reward_fn are all stateless - can share\n        initial_messages = _initial_messages(self.datum, renderer, self.chroma_tool)\n        reward_fn = TextAnswerReward(\n            gold_answers=self.datum[\"answer\"], format_coef=self.format_coef\n        )\n\n        return [\n            build_agent_tool_env(\n                renderer=renderer,\n                tools=[self.chroma_tool.search],\n                initial_messages=initial_messages,\n                reward_fn=reward_fn,\n                max_turns=self.max_turns,\n                max_trajectory_tokens=self.max_trajectory_tokens,\n            )\n            for _ in range(self.group_size)\n        ]\n\n    def logging_tags(self) -> list[str]:\n        return [self.datum.get(\"data_source\", \"unknown\")]\n\n\nclass SearchRLDataset(RLDataset):\n    \"\"\"Dataset that processes search EnvGroupBuilders once per epoch.\"\"\"\n\n    def __init__(\n        self,\n        env_group_builders: list[SearchEnvGroupBuilder],\n        batch_size: int,\n    ):\n        self.env_group_builders = env_group_builders\n        self.batch_size = batch_size\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        start = index * self.batch_size\n        end = start + self.batch_size\n        return self.env_group_builders[start:end]\n\n    def __len__(self) -> int:\n        return len(self.env_group_builders) // self.batch_size\n\n\n@chz.chz\nclass SearchR1DatasetBuilder(RLDatasetBuilder):\n    \"\"\"Build an RL dataset over SearchR1 tasks with ChromaTool.\"\"\"\n\n    model_name_for_tokenizer: str\n    # ChromaTool connection params\n    chroma_host: str\n    chroma_port: int\n    chroma_collection_name: str\n    retrieval_config: RetrievalConfig = RetrievalConfig()\n    # Dataset params\n    batch_size: int\n    group_size: int\n    renderer_name: str | None = None\n    max_turns: int = 5\n    format_coef: float = 0.1\n    max_trajectory_tokens: int = 32 * 1024\n    seed: int = 0\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:\n        # Create shared ChromaTool\n        chroma_tool = await ChromaTool.build(\n            chroma_host=self.chroma_host,\n            chroma_port=self.chroma_port,\n            chroma_collection_name=self.chroma_collection_name,\n            retrieval_config=self.retrieval_config,\n        )\n\n        data = download_search_r1_dataset(\"train\")\n        rng = random.Random(self.seed)\n        rng.shuffle(data)\n\n        env_builders = [\n            SearchEnvGroupBuilder(\n                datum=datum,\n                model_name=self.model_name_for_tokenizer,\n                renderer_name=self.renderer_name,\n                max_turns=self.max_turns,\n                group_size=self.group_size,\n                chroma_tool=chroma_tool,\n                format_coef=self.format_coef,\n                max_trajectory_tokens=self.max_trajectory_tokens,\n            )\n            for datum in data\n        ]\n        dataset = SearchRLDataset(\n            env_group_builders=env_builders,\n            batch_size=self.batch_size,\n        )\n        return dataset, None\n"
  },
  {
    "path": "tinker_cookbook/recipes/search_tool/tools.py",
    "content": "from __future__ import annotations\n\nimport asyncio\nimport logging\nimport re\nimport string\nfrom dataclasses import dataclass\nfrom functools import reduce\nfrom typing import Annotated\n\nimport chromadb\nimport chz\nimport google.genai as genai\nfrom chromadb.api import AsyncClientAPI\nfrom chromadb.api.types import QueryResult\nfrom chromadb.config import Settings\n\nfrom tinker_cookbook.recipes.search_tool.embedding import (\n    get_gemini_client,\n    get_gemini_embedding,\n)\nfrom tinker_cookbook.renderers import get_text_content\nfrom tinker_cookbook.renderers.base import Message\nfrom tinker_cookbook.tool_use import ToolResult, simple_tool_result, tool\n\n\ndef normalize_answer(s: str) -> str:\n    \"\"\"Normalize answer by lowercasing, removing punctuation, articles, and fixing whitespace.\"\"\"\n\n    def remove_articles(text: str) -> str:\n        return re.sub(r\"\\b(a|an|the)\\b\", \" \", text)\n\n    def white_space_fix(text: str) -> str:\n        return \" \".join(text.split())\n\n    def remove_punc(text: str) -> str:\n        exclude = set(string.punctuation)\n        return \"\".join(ch for ch in text if ch not in exclude)\n\n    def lower(text: str) -> str:\n        return text.lower()\n\n    # Apply transformations in order using reduce\n    transformations = [lower, remove_punc, remove_articles, white_space_fix]\n    return reduce(lambda text, func: func(text), transformations, s)\n\n\nlogger = logging.getLogger(__name__)\n\n_CONNECTION_SEMAPHORE = asyncio.Semaphore(128)\n\n\n@chz.chz\nclass EmbeddingConfig:\n    model_name: str = \"gemini-embedding-001\"\n    embedding_dim: int = 768\n    task_type: str = \"RETRIEVAL_QUERY\"\n\n\n@chz.chz\nclass RetrievalConfig:\n    n_results: int = 3\n    embedding_config: EmbeddingConfig = EmbeddingConfig()\n\n\nclass ChromaTool:\n    \"\"\"Search tool using ChromaDB + Gemini embeddings.\n\n    Pickle support: async clients are not pickleable (network connections).\n    ``__getstate__`` excludes them; ``_ensure_clients()`` lazily reconnects\n    before first use after deserialization. 
Requires ``build()`` so that\n    connection params (host, port) are available for reconnection.\n    \"\"\"\n\n    def __init__(\n        self,\n        chroma_client: AsyncClientAPI,\n        gemini_client: genai.Client,\n        collection_name: str,\n        retrieval_config: RetrievalConfig,\n        max_retries: int,\n        initial_retry_delay: int,\n        # Connection params stored for reconnection after pickle roundtrip.\n        # Set automatically by build(); None if constructed directly.\n        chroma_host: str | None = None,\n        chroma_port: int | None = None,\n    ):\n        self._chroma_client: AsyncClientAPI | None = chroma_client\n        self._gemini_client: genai.Client | None = gemini_client\n        self._collection_name = collection_name\n        self._retrieval_config = retrieval_config\n        self._max_retries = max_retries\n        self._initial_retry_delay = initial_retry_delay\n        self._chroma_host = chroma_host\n        self._chroma_port = chroma_port\n\n    def __getstate__(self) -> dict:\n        \"\"\"Exclude non-pickleable async clients from pickle state.\"\"\"\n        state = self.__dict__.copy()\n        state[\"_chroma_client\"] = None\n        state[\"_gemini_client\"] = None\n        return state\n\n    async def _ensure_clients(self) -> tuple[AsyncClientAPI, genai.Client]:\n        \"\"\"Return live clients, reconnecting if needed after deserialization.\"\"\"\n        if self._chroma_client is None:\n            if self._chroma_host is None or self._chroma_port is None:\n                raise RuntimeError(\n                    \"Cannot reconnect ChromaTool: connection params not set. \"\n                    \"Use ChromaTool.build() to enable pickle support.\"\n                )\n            self._chroma_client = await chromadb.AsyncHttpClient(\n                host=self._chroma_host,\n                port=self._chroma_port,\n                settings=Settings(anonymized_telemetry=False),\n            )\n        if self._gemini_client is None:\n            self._gemini_client = get_gemini_client()\n        return self._chroma_client, self._gemini_client\n\n    @staticmethod\n    async def build(\n        chroma_host: str,\n        chroma_port: int,\n        chroma_collection_name: str,\n        retrieval_config: RetrievalConfig = RetrievalConfig(),\n        max_retries: int = 10,\n        initial_retry_delay: int = 1,\n        # Optional shared resources - None means build your own\n        chroma_client: AsyncClientAPI | None = None,\n        gemini_client: genai.Client | None = None,\n    ) -> ChromaTool:\n        \"\"\"Async factory for building ChromaTool.\n\n        Args:\n            chroma_host: ChromaDB server host.\n            chroma_port: ChromaDB server port.\n            chroma_collection_name: Name of the ChromaDB collection to query.\n            retrieval_config: Configuration for retrieval (n_results, embedding settings).\n            max_retries: Max retries for ChromaDB queries.\n            initial_retry_delay: Initial delay between retries (exponential backoff).\n            chroma_client: Optional pre-built ChromaDB client (for sharing across tools).\n            gemini_client: Optional pre-built Gemini client (for sharing across tools).\n        \"\"\"\n        if chroma_client is None:\n            chroma_client = await chromadb.AsyncHttpClient(\n                host=chroma_host,\n                port=chroma_port,\n                settings=Settings(anonymized_telemetry=False),\n            )\n        if 
gemini_client is None:\n            gemini_client = get_gemini_client()\n        return ChromaTool(\n            chroma_client,\n            gemini_client,\n            chroma_collection_name,\n            retrieval_config,\n            max_retries,\n            initial_retry_delay,\n            chroma_host=chroma_host,\n            chroma_port=chroma_port,\n        )\n\n    async def _get_embeddings_with_retry(\n        self, gemini_client: genai.Client, query_list: list[str]\n    ) -> list[list[float]]:\n        embedding_config = self._retrieval_config.embedding_config\n        return await get_gemini_embedding(\n            gemini_client,\n            query_list,\n            embedding_config.model_name,\n            embedding_config.embedding_dim,\n            embedding_config.task_type,\n        )\n\n    async def _query_chroma_with_retry(\n        self, chroma_client: AsyncClientAPI, query_embeddings: list[list[float]]\n    ) -> QueryResult:\n        for attempt in range(self._max_retries):\n            collection = await chroma_client.get_collection(self._collection_name)\n            try:\n                results = await collection.query(\n                    query_embeddings=query_embeddings,  # pyright: ignore[reportArgumentType]\n                    n_results=self._retrieval_config.n_results,\n                )\n                return results\n            except Exception as e:\n                if attempt < self._max_retries - 1:\n                    wait_time = self._initial_retry_delay * (1.5**attempt)\n                    logger.error(\n                        f\"ChromaDB query attempt {attempt + 1}/{self._max_retries} \"\n                        f\"failed: {e}. Retrying in {wait_time}s...\"\n                    )\n                    await asyncio.sleep(wait_time)\n                    continue\n                raise e\n\n        raise RuntimeError(\"All ChromaDB query attempts failed\")\n\n    @tool\n    async def search(\n        self,\n        query_list: Annotated[\n            list[str],\n            \"A list of fully-formed semantic queries. 
The tool will return search results for each query.\",\n        ],\n    ) -> ToolResult:\n        \"\"\"Search Wikipedia for relevant information based on the given query.\"\"\"\n        chroma_client, gemini_client = await self._ensure_clients()\n        async with _CONNECTION_SEMAPHORE:\n            embeddings = await self._get_embeddings_with_retry(gemini_client, query_list)\n            results = await self._query_chroma_with_retry(chroma_client, embeddings)\n\n        # Format same as original ChromaToolClient.invoke()\n        message_content = \"\"\n        documents_list = results[\"documents\"] or []\n        for query, documents in zip(query_list, documents_list):\n            message_content += f\"Query: {query}\\n\"\n            for doc_i, doc in enumerate(documents):\n                message_content += f\"Document {doc_i + 1}:\\n\"\n                message_content += f\"{doc}\\n\"\n\n        return simple_tool_result(message_content)\n\n\n@dataclass\nclass TextAnswerReward:\n    \"\"\"Reward function to check text answer against gold answers.\n\n    formula: format_coef * (correct_format - 1) + correct_answer\n    \"\"\"\n\n    gold_answers: list[str]\n    format_coef: float = 0.1\n\n    async def __call__(self, history: list[Message]) -> tuple[float, dict[str, float]]:\n        \"\"\"Grade the completed episode by checking the final assistant message.\"\"\"\n        # Find the last assistant message\n        final_message = None\n        for msg in reversed(history):\n            if msg.get(\"role\") == \"assistant\":\n                final_message = msg\n                break\n\n        if final_message is None:\n            return 0.0, {\"format\": 0.0, \"correct\": 0.0}\n\n        # Use get_text_content to properly handle thinking models (o1, o3)\n        content = get_text_content(final_message)\n\n        correct_format = float(self._extract_answer(content) is not None)\n        correct_answer = float(self._check_answer(content))\n\n        reward = self.format_coef * (correct_format - 1) + correct_answer\n        return reward, {\"format\": correct_format, \"correct\": correct_answer}\n\n    def _extract_answer(self, text: str) -> str | None:\n        if \"Answer:\" not in text:\n            return None\n        parts = text.split(\"Answer:\")\n        if len(parts) != 2:\n            return None\n        return parts[1].strip()\n\n    def _check_answer(self, text: str) -> bool:\n        model_answer = self._extract_answer(text)\n        if model_answer is None or len(self.gold_answers) == 0:\n            return False\n        for gold in self.gold_answers:\n            if normalize_answer(model_answer) == normalize_answer(gold):\n                return True\n        return False\n"
  },
  {
    "path": "tinker_cookbook/recipes/search_tool/train.py",
    "content": "\"\"\"CLI for Search-R1 replication.\"\"\"\n\nimport asyncio\nfrom datetime import datetime\nfrom pathlib import Path\n\nimport chz\n\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.recipes.search_tool.search_env import SearchR1DatasetBuilder\nfrom tinker_cookbook.recipes.search_tool.tools import (\n    EmbeddingConfig,\n    RetrievalConfig,\n)\nfrom tinker_cookbook.rl import train\n\n\n@chz.chz\nclass CLIConfig:\n    # Model parameters\n    model_name: str = \"Qwen/Qwen3-4B-Instruct-2507\"\n    lora_rank: int = 32\n    renderer_name: str | None = None\n\n    # Training parameters\n    learning_rate: float = 4e-5\n    batch_size: int = 512\n    seed: int = 2\n    max_tokens: int = 1024\n    eval_every: int = 0\n\n    # Dataset parameters\n    group_size: int = 8\n    max_turns: int = 5\n    format_coef: float = 0.1\n    max_trajectory_tokens: int = 32 * 1024\n\n    # Chroma configuration\n    chroma_host: str = \"localhost\"\n    chroma_port: int = 8000\n    chroma_collection_name: str = \"wiki_embeddings\"\n    n_results: int = 3\n    embedding_model_name: str = \"gemini-embedding-001\"\n    embedding_dim: int = 768\n\n    # Streaming configuration\n    stream_minibatch: bool = False\n    num_minibatches: int = 4\n\n    # Logging parameters\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\nasync def cli_main(cli_config: CLIConfig) -> None:\n    # Build retrieval config\n    retrieval_config = RetrievalConfig(\n        n_results=cli_config.n_results,\n        embedding_config=EmbeddingConfig(\n            model_name=cli_config.embedding_model_name,\n            embedding_dim=cli_config.embedding_dim,\n        ),\n    )\n\n    # Get renderer name\n    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(\n        cli_config.model_name\n    )\n\n    builder = SearchR1DatasetBuilder(\n        batch_size=cli_config.batch_size,\n        group_size=cli_config.group_size,\n        renderer_name=renderer_name,\n        model_name_for_tokenizer=cli_config.model_name,\n        chroma_host=cli_config.chroma_host,\n        chroma_port=cli_config.chroma_port,\n        chroma_collection_name=cli_config.chroma_collection_name,\n        retrieval_config=retrieval_config,\n        seed=cli_config.seed,\n        max_turns=cli_config.max_turns,\n        format_coef=cli_config.format_coef,\n        max_trajectory_tokens=cli_config.max_trajectory_tokens,\n    )\n\n    # Configure streaming minibatch\n    if cli_config.stream_minibatch:\n        stream_minibatch_config = train.StreamMinibatchConfig(\n            groups_per_batch=cli_config.batch_size,\n            num_minibatches=cli_config.num_minibatches,\n        )\n        bs_str = f\"bs{cli_config.batch_size}_stream\"\n    else:\n        stream_minibatch_config = None\n        bs_str = f\"bs{cli_config.batch_size}\"\n\n    # Build run name\n    model_name_short = cli_config.model_name.lower().replace(\"/\", \"-\")\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    run_name = (\n        f\"search_r1_{model_name_short}_{bs_str}_gs{cli_config.group_size}_\"\n        f\"seed{cli_config.seed}_lr{cli_config.learning_rate}_\"\n        f\"rank{cli_config.lora_rank}_{date_and_time}\"\n    )\n\n    # Set log path\n    if cli_config.log_path is not None:\n        log_path = cli_config.log_path\n    else:\n        
log_path = f\"/tmp/tinker-examples/rl_search/{run_name}\"\n\n    if cli_config.wandb_name is not None:\n        wandb_name = cli_config.wandb_name\n    else:\n        wandb_name = run_name\n\n    # Validate /tmp exists\n    if not Path(\"/tmp\").exists():\n        raise ValueError(\"/tmp does not exist\")\n\n    # Check log directory\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    # Build training config\n    config = train.Config(\n        model_name=cli_config.model_name,\n        renderer_name=renderer_name,\n        log_path=log_path,\n        dataset_builder=builder,\n        learning_rate=cli_config.learning_rate,\n        max_tokens=cli_config.max_tokens,\n        eval_every=cli_config.eval_every,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=wandb_name,\n        lora_rank=cli_config.lora_rank,\n        stream_minibatch_config=stream_minibatch_config,\n        max_steps=cli_config.max_steps,\n    )\n\n    # Run training\n    await train.main(config)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config))\n"
  },
  {
    "path": "tinker_cookbook/recipes/sl_basic.py",
    "content": "import asyncio\nimport sys\n\nimport chz\n\nfrom tinker_cookbook import cli_utils, model_info\nfrom tinker_cookbook.recipes.chat_sl import chat_datasets\nfrom tinker_cookbook.renderers import TrainOnWhat\nfrom tinker_cookbook.supervised import train\nfrom tinker_cookbook.supervised.data import FromConversationFileBuilder\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilderCommonConfig\n\n\ndef build_config_blueprint() -> chz.Blueprint[train.Config]:\n    model_name = \"meta-llama/Llama-3.1-8B\"\n    renderer_name = model_info.get_recommended_renderer_name(model_name)\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=model_name,\n        renderer_name=renderer_name,\n        max_length=32768,\n        batch_size=128,\n        train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES,\n    )\n    dataset = chat_datasets.NoRobotsBuilder(common_config=common_config)\n    if 0:  # To swap in your own dataset:\n        dataset = FromConversationFileBuilder(\n            common_config=common_config, file_path=\"/path/to/your/dataset.jsonl\"\n        )\n        # ^^^ Create a dataset from a JSONL file in the same format as\n        # tinker_cookbook/example_data/conversations.jsonl\n    return chz.Blueprint(train.Config).apply(\n        {\n            \"log_path\": \"/tmp/tinker-examples/sl_basic\",\n            \"model_name\": model_name,\n            \"renderer_name\": renderer_name,\n            \"dataset_builder\": dataset,\n            \"learning_rate\": 2e-4,\n            \"lr_schedule\": \"linear\",\n            \"num_epochs\": 1,\n            \"eval_every\": 8,\n        }\n    )\n\n\ndef main(config: train.Config):\n    # Avoid clobbering log dir from your previous run:\n    cli_utils.check_log_dir(config.log_path, behavior_if_exists=\"ask\")\n    asyncio.run(train.main(config))\n\n\nif __name__ == \"__main__\":\n    blueprint = build_config_blueprint()\n    blueprint.make_from_argv(sys.argv[1:])\n    main(blueprint.make())\n"
  },
  {
    "path": "tinker_cookbook/recipes/sl_loop.py",
    "content": "\"\"\"\nMinimal supervised fine-tuning script without abstractions.\nUses existing modules but with a simple, flat training loop.\n\"\"\"\n\nimport logging\nimport time\n\nimport chz\nimport datasets\nimport tinker\n\nfrom tinker_cookbook import checkpoint_utils, model_info, renderers\nfrom tinker_cookbook.supervised.common import compute_mean_nll\nfrom tinker_cookbook.supervised.data import conversation_to_datum\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.utils import ml_log\n\nlogger = logging.getLogger(__name__)\nlogging.getLogger(\"httpx\").setLevel(logging.WARN)\n\n\n@chz.chz\nclass Config:\n    base_url: str | None = None\n    log_path: str = \"/tmp/tinker-examples/sl-loop\"\n    model_name: str = \"meta-llama/Llama-3.1-8B\"\n    batch_size: int = 128\n    learning_rate: float = 1e-4\n    max_length: int = 32768\n    train_on_what: renderers.TrainOnWhat = renderers.TrainOnWhat.ALL_ASSISTANT_MESSAGES\n    lora_rank: int = 32\n    save_every: int = 20  # 0 = disabled\n    ttl_seconds: int | None = 604800  # 7 days\n\n\ndef main(config: Config):\n    # Setup logging\n    ml_logger = ml_log.setup_logging(\n        log_dir=config.log_path,\n        wandb_project=None,\n        wandb_name=None,\n        config=config,\n        do_configure_logging_module=True,\n    )\n\n    # Get tokenizer and renderer\n    tokenizer = get_tokenizer(config.model_name)\n    renderer_name = model_info.get_recommended_renderer_name(config.model_name)\n    renderer = renderers.get_renderer(renderer_name, tokenizer)\n    logger.info(f\"Using renderer: {renderer_name}\")\n\n    # Load No Robots dataset\n    logger.info(\"Loading dataset...\")\n    dataset = datasets.load_dataset(\"HuggingFaceH4/no_robots\")\n    assert isinstance(dataset, datasets.DatasetDict)\n    train_dataset = dataset[\"train\"]\n\n    # Drop the last incomplete batch (like PyTorch's drop_last=True) — a partial\n    # batch has different effective gradient magnitude, which can cause a training spike.\n    n_train_batches = len(train_dataset) // config.batch_size\n    n_dropped = len(train_dataset) % config.batch_size\n    if n_dropped:\n        logger.info(\n            f\"Dropping last {n_dropped} examples to keep batch size uniform at {config.batch_size}\"\n        )\n    logger.info(f\"Train batches: {n_train_batches}\")\n\n    # Setup training client\n    service_client = tinker.ServiceClient(base_url=config.base_url)\n\n    # Check for resuming\n    resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)\n    if resume_info:\n        training_client = service_client.create_training_client_from_state_with_optimizer(\n            resume_info.state_path\n        )\n        start_batch = resume_info.batch\n        logger.info(f\"Resuming from batch {start_batch}\")\n    else:\n        training_client = service_client.create_lora_training_client(\n            base_model=config.model_name, rank=config.lora_rank\n        )\n        start_batch = 0\n\n    # Training loop (single epoch)\n    logger.info(f\"Training for {n_train_batches} steps\")\n\n    # Shuffle dataset\n    train_dataset = train_dataset.shuffle(seed=0)\n\n    for batch_idx in range(start_batch, n_train_batches):\n        start_time = time.time()\n        step = batch_idx\n        metrics = {}\n\n        # Save checkpoint\n        if config.save_every > 0 and step % config.save_every == 0 and step > 0:\n            checkpoint_utils.save_checkpoint(\n                training_client=training_client,\n                
name=f\"{step:06d}\",\n                log_path=config.log_path,\n                kind=\"state\",\n                loop_state={\"batch\": batch_idx},\n                ttl_seconds=config.ttl_seconds,\n            )\n\n        # Linear learning rate schedule\n        lr_mult = max(0.0, 1.0 - step / n_train_batches)\n        current_lr = config.learning_rate * lr_mult\n        adam_params = tinker.AdamParams(learning_rate=current_lr, beta1=0.9, beta2=0.95, eps=1e-8)\n\n        # Get training batch and convert to datums online\n        batch_start = batch_idx * config.batch_size\n        batch_end = min((batch_idx + 1) * config.batch_size, len(train_dataset))\n        batch_rows = train_dataset.select(range(batch_start, batch_end))\n\n        batch = [\n            conversation_to_datum(\n                row[\"messages\"],  # type: ignore\n                renderer,\n                config.max_length,\n                config.train_on_what,\n            )\n            for row in batch_rows\n        ]\n\n        # Training step\n        fwd_bwd_future = training_client.forward_backward(batch, loss_fn=\"cross_entropy\")\n        optim_step_future = training_client.optim_step(adam_params)\n\n        fwd_bwd_result = fwd_bwd_future.result()\n        optim_result = optim_step_future.result()\n\n        if optim_result.metrics:\n            metrics.update(optim_result.metrics)\n\n        # Compute train metrics\n        train_logprobs = [x[\"logprobs\"] for x in fwd_bwd_result.loss_fn_outputs]\n        train_weights = [d.loss_fn_inputs[\"weights\"] for d in batch]\n        train_nll = compute_mean_nll(train_logprobs, train_weights)\n\n        # Log metrics\n        metrics.update(\n            num_sequences=len(batch),\n            num_tokens=sum(d.model_input.length for d in batch),\n            learning_rate=current_lr,\n            train_mean_nll=train_nll,\n            progress=step / n_train_batches,\n            time_total=time.time() - start_time,\n        )\n        ml_logger.log_metrics(metrics=metrics, step=step)\n\n    # Save final checkpoint\n    checkpoint_utils.save_checkpoint(\n        training_client=training_client,\n        name=\"final\",\n        log_path=config.log_path,\n        kind=\"both\",\n        loop_state={\"batch\": n_train_batches},\n        ttl_seconds=None,\n    )\n\n    ml_logger.close()\n    logger.info(\"Training completed\")\n\n\nif __name__ == \"__main__\":\n    chz.nested_entrypoint(main)\n"
  },
  {
    "path": "tinker_cookbook/recipes/verifiers_rl/README.md",
    "content": "# RL Training with Tinker + Environments Hub (Verifiers)\n\n[Verifiers](https://github.com/primeintellect-ai/verifiers) is a library for creating RL environments for LLMs, including many community implementations featured on Prime Intellect's [Environments Hub](https://app.primeintellect.ai/dashboard/environments). This recipe allows all text-based environments from the Environments Hub to be used with Tinker for RL training.\n\nTo use this recipe, you need to have your chosen environment module (a self-contained Python package) installed in your project. You can install environments from the Environments Hub using the `prime` CLI:\n\n```bash\nuv tool install prime # or pipx install prime\nprime env install user/env-id # ex. prime env install primeintellect/reverse-text\n```\n\nExamples:\n- [primeintellect/reverse-text](https://app.primeintellect.ai/dashboard/environments/primeintellect/reverse-text)\n- [primeintellect/alphabet-sort](https://app.primeintellect.ai/dashboard/environments/primeintellect/alphabet-sort)\n- [primeintellect/math-python](https://app.primeintellect.ai/dashboard/environments/primeintellect/math-python)\n- [will/wordle](https://app.primeintellect.ai/dashboard/environments/will/wordle)\n\nYou can then run the recipe with the following command, where `vf_env_id` is the ID (just `env-id`) of the environment, and `vf_env_args` is an optional JSON string of arguments to pass when loading the environment.\n\n```bash\npython -m tinker_cookbook.recipes.verifiers_rl.train vf_env_id=env-id vf_env_args='{}' ...\n```\n\nThe reverse-text example as configured should climb from ~0.2 to ~0.35 in 32 steps.\n\nYou can also evaluate offline:\n\n```bash\npython -m tinker_cookbook.recipes.verifiers_rl.evaluate vf_env_id=env-id vf_env_args='{}' ...\n```\n\nThis recipe also includes a standalone `AsyncOpenAI`-compatible client implemented with Tinker, which can be adapted for other applications.\n\n**Potential footgun:**\n- Some Environments Hub implementations involve users writing their own `<think>` parsers (e.g. for use with reasoning RL starting on Instruct models). Despite being Instruct models, the Qwen3 models/tokenizers all use the same tokenizer chat template, which will strip any observed `<think>` sections automatically (which may be inadvertently penalized by reward functions). Users should either modify the renderer, tokenizer chat template, or environment module if observing issues with thinking sections from Qwen3 models.\n"
  },
  {
    "path": "tinker_cookbook/recipes/verifiers_rl/evaluate.py",
    "content": "from __future__ import annotations\n\nimport asyncio\nimport json\nimport time\n\nimport chz\nimport numpy as np\nimport tinker\nimport verifiers as vf\nfrom verifiers.utils.message_utils import messages_to_printable\n\nfrom tinker_cookbook import checkpoint_utils, model_info, renderers\nfrom tinker_cookbook.recipes.verifiers_rl.tinker_openai import TinkerAsyncOpenAIClient\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\ndef log_results(\n    results: vf.GenerateOutputs,\n    vf_env_id: str,\n    model_name: str,\n    num_examples: int,\n    rollouts_per_example: int,\n    time_s: float,\n):\n    print(f\"Evaluation completed in {time_s:.2f} seconds\")\n    print(\"--- Evaluation ---\")\n    print(f\"Environment: {vf_env_id}\")\n    print(f\"Model: {model_name}\")\n    print(f\"Examples: {num_examples}\")\n    print(f\"Rollouts per example: {rollouts_per_example}\")\n    print(\"--- Example ---\")\n    printable_prompts = [messages_to_printable(p) for p in results[\"prompt\"]]\n    printable_completions = [messages_to_printable(c) for c in results[\"completion\"]]\n    vf.print_prompt_completions_sample(\n        prompts=printable_prompts,\n        completions=printable_completions,\n        errors=[],  # Required argument added in verifiers 0.1.9\n        rewards=results[\"reward\"],\n        step=0,\n    )\n    print(\"--- All ---\")\n    print(\"Rewards:\")\n    print(\n        f\"reward: avg - {sum(results['reward']) / len(results['reward']):.3f}, std - {np.std(results['reward']):.3f}\"\n    )\n    r = rollouts_per_example\n    n = len(results[\"reward\"]) // r\n    for i in range(r):\n        # rounded to 3 decimal places\n        trials = [round(results[\"reward\"][(i * n) + j], 3) for j in range(n)]\n        out = f\"r{i + 1}: {trials}\"\n        print(out)\n    for k in results[\"metrics\"]:\n        v = results[\"metrics\"][k]\n        print(f\"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}\")\n        for i in range(r):\n            # rounded to 3 decimal places\n            trials = [round(v[(i * n) + j], 3) for j in range(n)]\n            out = f\"r{i + 1}: {trials}\"\n            print(out)\n\n\nasync def evaluate(\n    vf_env_id: str,\n    vf_env_args: dict,\n    model_name: str | None,\n    num_examples: int,\n    rollouts_per_example: int,\n    max_concurrent: int,\n    max_tokens: int,\n    temperature: float,\n    model_path: str | None = None,\n):\n    service = tinker.ServiceClient()\n\n    # If model_path is provided, get the base model from the training run\n    if model_path is not None:\n        rest_client = service.create_rest_client()\n        training_run = await rest_client.get_training_run_by_tinker_path_async(model_path)\n        if model_name:\n            if model_name != training_run.base_model:\n                raise ValueError(\n                    f\"Model name {model_name} does not match training run base model {training_run.base_model}\"\n                )\n        else:\n            model_name = training_run.base_model\n\n    if model_name is None:\n        raise ValueError(\"model_name or model_path must be provided\")\n\n    env = vf.load_environment(vf_env_id, **vf_env_args)\n    tokenizer = get_tokenizer(model_name)\n    renderer_name = None\n    if model_path is not None:\n        renderer_name = await checkpoint_utils.get_renderer_name_from_checkpoint_async(\n            service, model_path\n        )\n    if renderer_name is None:\n        renderer_name = 
model_info.get_recommended_renderer_name(model_name)\n    print(f\"Using renderer: {renderer_name}\")\n    renderer = renderers.get_renderer(renderer_name, tokenizer)\n\n    # Create sampling client from checkpoint path or base model\n    if model_path:\n        sampling = service.create_sampling_client(model_path=model_path, base_model=model_name)\n    else:\n        sampling = service.create_sampling_client(base_model=model_name)\n\n    client = TinkerAsyncOpenAIClient(sampling, renderer, tokenizer)\n    start_time = time.time()\n    results = env.evaluate_sync(\n        client=client,\n        model=model_name,\n        num_examples=num_examples,\n        rollouts_per_example=rollouts_per_example,\n        max_concurrent=max_concurrent,\n        sampling_args={\n            \"max_tokens\": max_tokens,\n            \"temperature\": temperature,\n        },\n    )\n    end_time = time.time()\n    log_results(\n        results,\n        vf_env_id,\n        model_name,\n        num_examples,\n        rollouts_per_example,\n        end_time - start_time,\n    )\n    return results\n\n\n@chz.chz\nclass CLIConfig:\n    model_name: str | None = None  # Base model name (auto-detected from checkpoint if not provided)\n    model_path: str | None = None  # Path to checkpoint (e.g., from checkpoints.jsonl sampler_path)\n    vf_env_id: str = \"reverse-text\"\n    vf_env_args: str | None = None  # JSON string\n    num_examples: int = 5\n    rollouts_per_example: int = 3\n    max_concurrent: int = 32\n    max_tokens: int = 1024\n    temperature: float = 1.0\n\n\nasync def cli_main(cfg: CLIConfig):\n    env_args = json.loads(cfg.vf_env_args) if cfg.vf_env_args else {}\n    return await evaluate(\n        vf_env_id=cfg.vf_env_id,\n        vf_env_args=env_args,\n        model_name=cfg.model_name,\n        num_examples=cfg.num_examples,\n        rollouts_per_example=cfg.rollouts_per_example,\n        max_concurrent=cfg.max_concurrent,\n        max_tokens=cfg.max_tokens,\n        temperature=cfg.temperature,\n        model_path=cfg.model_path,\n    )\n\n\nif __name__ == \"__main__\":\n    cfg = chz.entrypoint(CLIConfig)\n\n    asyncio.run(cli_main(cfg))\n"
  },
  {
    "path": "tinker_cookbook/recipes/verifiers_rl/tinker_openai.py",
    "content": "\"\"\"\nOpenAI-compatible client backed by Tinker sampling.\n\nImplements OpenAI client semantics for:\n- chat.completions.create(...)\n- completions.create(...)\n\nReturns OpenAI types (ChatCompletion / Completion) constructed from sampled tokens.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport time\nfrom typing import Any, Literal, overload\n\nimport tinker\nfrom openai import AsyncOpenAI\nfrom openai._streaming import AsyncStream\nfrom openai.resources.chat import AsyncChat as OpenAIAsyncChat\nfrom openai.resources.chat.completions import AsyncCompletions as OpenAIAsyncChatCompletions\nfrom openai.resources.completions import AsyncCompletions as OpenAIAsyncCompletions\nfrom openai.types.chat.chat_completion import ChatCompletion\nfrom openai.types.completion import Completion\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n\nclass TinkerAsyncOpenAIClient(AsyncOpenAI):\n    \"\"\"\n    OpenAI-compatible async client that routes calls to a Tinker SamplingClient.\n    \"\"\"\n\n    def __init__(\n        self,\n        sampling_client: tinker.SamplingClient,\n        renderer: renderers.Renderer,\n        tokenizer: Tokenizer,\n    ) -> None:\n        super().__init__(api_key=\"tinker\", base_url=\"http://localhost\")\n        self.sampling_client = sampling_client\n        self.renderer = renderer\n        self.tokenizer = tokenizer\n\n    def set_sampling_client(self, sampling_client: tinker.SamplingClient) -> None:\n        self.sampling_client = sampling_client\n\n    @property\n    def chat(self) -> OpenAIAsyncChat:\n        return TinkerAsyncChat(self)\n\n    @property\n    def completions(self) -> OpenAIAsyncCompletions:\n        return TinkerCompletions(self)\n\n\nclass TinkerChatCompletions(OpenAIAsyncChatCompletions):\n    def __init__(self, parent: TinkerAsyncOpenAIClient) -> None:\n        self._parent = parent\n\n    @overload\n    async def create(\n        self, *args: Any, stream: Literal[True], **kwargs: Any\n    ) -> AsyncStream[Any]: ...\n\n    @overload\n    async def create(\n        self, *args: Any, stream: Literal[False] = False, **kwargs: Any\n    ) -> ChatCompletion: ...\n\n    @overload\n    async def create(self, *args: Any, stream: bool, **kwargs: Any) -> ChatCompletion: ...\n\n    async def create(self, *args: Any, **kwargs: Any) -> ChatCompletion | AsyncStream[Any]:\n        model = kwargs.get(\"model\", \"tinker\")\n        messages = kwargs.get(\"messages\", [])\n        if kwargs.get(\"tools\"):\n            raise NotImplementedError(\"Tool calling is not yet supported by this model's renderer.\")\n        if kwargs.get(\"stream\", False):\n            raise ValueError(\"stream=True not supported by TinkerAsyncOpenAIClient\")\n        sampling_args = {k: v for k, v in kwargs.items() if k not in (\"model\", \"messages\", \"tools\")}\n\n        stop = sampling_args.get(\"stop\", self._parent.renderer.get_stop_sequences())\n        max_tokens = sampling_args.get(\"max_tokens\") or sampling_args.get(\"max_completion_tokens\")\n\n        model_input = self._parent.renderer.build_generation_prompt(messages)\n        prompt_token_ids: list[int] = model_input.to_ints()\n\n        sample = await self._parent.sampling_client.sample_async(\n            prompt=model_input,\n            num_samples=1,\n            sampling_params=tinker.SamplingParams(\n                temperature=float(sampling_args.get(\"temperature\", 1.0)),\n                max_tokens=int(max_tokens or 128),\n                
top_p=float(sampling_args.get(\"top_p\", 1.0)),\n                top_k=int(sampling_args.get(\"top_k\", -1)),\n                stop=stop,\n            ),\n        )\n        seq = sample.sequences[0]\n        completion_token_ids: list[int] = seq.tokens\n        logprobs: list[float] = seq.logprobs or [0.0] * len(completion_token_ids)\n\n        assistant_message, parse_success = self._parent.renderer.parse_response(\n            completion_token_ids\n        )\n        finish_reason = \"stop\" if parse_success else \"length\"\n\n        # Convert list content to string for OpenAI compatibility\n        openai_content = renderers.format_content_as_string(assistant_message[\"content\"])\n\n        # Build OpenAI-compatible message\n        openai_message: dict[str, Any] = {\n            \"role\": \"assistant\",\n            \"content\": openai_content,\n        }\n        # Include tool_calls if present\n        if \"tool_calls\" in assistant_message:\n            openai_message[\"tool_calls\"] = [\n                {\n                    \"id\": tc.id or f\"call_{i}\",\n                    \"type\": \"function\",\n                    \"function\": {\"name\": tc.function.name, \"arguments\": tc.function.arguments},\n                }\n                for i, tc in enumerate(assistant_message[\"tool_calls\"])\n            ]\n\n        response_dict: dict[str, Any] = {\n            \"id\": \"tinker-chatcmpl\",\n            \"object\": \"chat.completion\",\n            \"created\": int(time.time()),\n            \"model\": model,\n            \"choices\": [\n                {\n                    \"index\": 0,\n                    \"message\": openai_message,\n                    \"finish_reason\": finish_reason,\n                    \"logprobs\": {\n                        \"content\": [\n                            {\"token\": f\"token_id:{tid}\", \"logprob\": lp, \"top_logprobs\": []}\n                            for tid, lp in zip(completion_token_ids, logprobs)\n                        ]\n                    },\n                }\n            ],\n            \"usage\": {\n                \"prompt_tokens\": len(prompt_token_ids),\n                \"completion_tokens\": len(completion_token_ids),\n                \"total_tokens\": len(prompt_token_ids) + len(completion_token_ids),\n            },\n        }\n        response = ChatCompletion.model_validate(response_dict)\n\n        object.__setattr__(response, \"prompt_token_ids\", prompt_token_ids)\n        object.__setattr__(response.choices[0], \"token_ids\", completion_token_ids)\n\n        return response\n\n\nclass TinkerCompletions(OpenAIAsyncCompletions):\n    def __init__(self, parent: TinkerAsyncOpenAIClient) -> None:\n        self._parent = parent\n\n    @overload\n    async def create(\n        self, *args: Any, stream: Literal[True], **kwargs: Any\n    ) -> AsyncStream[Completion]: ...\n\n    @overload\n    async def create(\n        self, *args: Any, stream: Literal[False] = False, **kwargs: Any\n    ) -> Completion: ...\n\n    @overload\n    async def create(\n        self, *args: Any, stream: bool, **kwargs: Any\n    ) -> Completion | AsyncStream[Completion]: ...\n\n    async def create(self, *args: Any, **kwargs: Any) -> Completion | AsyncStream[Completion]:\n        stream = bool(kwargs.get(\"stream\", False))\n        model = kwargs.get(\"model\", \"tinker\")\n        prompt = kwargs.get(\"prompt\", \"\")\n        sampling_args = {k: v for k, v in kwargs.items() if k not in (\"model\", \"prompt\")}\n\n        
prompt_token_ids: list[int] = self._parent.tokenizer.encode(prompt, add_special_tokens=True)\n        model_input = tinker.ModelInput.from_ints(prompt_token_ids)\n\n        sample = await self._parent.sampling_client.sample_async(\n            prompt=model_input,\n            num_samples=1,\n            sampling_params=tinker.SamplingParams(\n                temperature=float(sampling_args.get(\"temperature\", 1.0)),\n                max_tokens=int(sampling_args.get(\"max_tokens\", 128)),\n                top_p=float(sampling_args.get(\"top_p\", 1.0)),\n                top_k=int(sampling_args.get(\"top_k\", -1)),\n            ),\n        )\n        seq = sample.sequences[0]\n        completion_token_ids: list[int] = seq.tokens\n        logprobs: list[float] = seq.logprobs or [0.0] * len(completion_token_ids)\n\n        text = self._parent.tokenizer.decode(completion_token_ids)\n        tokens_str = [f\"token_id:{tid}\" for tid in completion_token_ids]\n        response_dict: dict[str, Any] = {\n            \"id\": \"tinker-cmpl\",\n            \"object\": \"text_completion\",\n            \"created\": int(time.time()),\n            \"model\": model,\n            \"choices\": [\n                {\n                    \"index\": 0,\n                    \"text\": text,\n                    \"finish_reason\": \"stop\",\n                    \"logprobs\": {\n                        \"tokens\": tokens_str,\n                        \"token_logprobs\": logprobs,\n                    },\n                }\n            ],\n            \"usage\": {\n                \"prompt_tokens\": len(prompt_token_ids),\n                \"completion_tokens\": len(completion_token_ids),\n                \"total_tokens\": len(prompt_token_ids) + len(completion_token_ids),\n            },\n        }\n        response = Completion.model_validate(response_dict)\n\n        object.__setattr__(response.choices[0], \"prompt_token_ids\", prompt_token_ids)\n        object.__setattr__(response.choices[0], \"token_ids\", completion_token_ids)\n\n        if stream:\n            return TinkerAsyncCompletionStream(response)\n        return response\n\n\nclass TinkerAsyncChat(OpenAIAsyncChat):\n    def __init__(self, parent: TinkerAsyncOpenAIClient) -> None:\n        self._parent = parent\n\n    @property\n    def completions(self) -> OpenAIAsyncChatCompletions:\n        return TinkerChatCompletions(self._parent)\n\n\nclass TinkerAsyncCompletionStream(AsyncStream[Completion]):\n    def __init__(self, final: Completion) -> None:\n        self._final = final\n\n    def __aiter__(self):\n        self._done = True\n        return self\n\n    async def __anext__(self) -> Completion:\n        raise StopAsyncIteration\n\n    def __await__(self):\n        async def _await_final():\n            return self._final\n\n        return _await_final().__await__()\n\n    async def get_final_response(self) -> Completion:\n        return self._final\n"
  },
  {
    "path": "tinker_cookbook/recipes/verifiers_rl/train.py",
    "content": "from __future__ import annotations\n\nimport asyncio\nimport json\nimport logging\nfrom datetime import datetime\nfrom typing import Any, cast\n\nimport chz\nfrom verifiers.utils.async_utils import maybe_semaphore\n\nfrom tinker_cookbook import cli_utils, model_info, renderers\nfrom tinker_cookbook.completers import TinkerTokenCompleter, TokenCompleter\nfrom tinker_cookbook.recipes.verifiers_rl.tinker_openai import TinkerAsyncOpenAIClient\nfrom tinker_cookbook.recipes.verifiers_rl.verifiers_env import (\n    VerifiersEnvGroupBuilder,\n    VerifiersRLDatasetBuilder,\n    convert_states_to_trajectory_group,\n)\nfrom tinker_cookbook.rl import train\nfrom tinker_cookbook.rl.types import EnvGroupBuilder, TrajectoryGroup\nfrom tinker_cookbook.tokenizer_utils import Tokenizer, get_tokenizer\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass CLIConfig:\n    # model configuration\n    model_name: str = \"Qwen/Qwen3-4B-Instruct-2507\"\n    lora_rank: int = 32\n\n    # environment configuration\n    vf_env_id: str = \"reverse-text\"\n    vf_env_args: str | None = None  # JSON string\n    dataset_n: int = -1\n    dataset_seed: int | None = None\n\n    # training hyperparameters\n    group_size: int = 8\n    groups_per_batch: int = 32\n    num_substeps: int = 1\n    learning_rate: float = 1e-5\n    max_tokens: int = 512\n    temperature: float = 1.0\n    kl_penalty_coef: float = 0.0\n    max_concurrent_generation: int = -1\n    max_concurrent_scoring: int = -1\n\n    # logging configuration\n    eval_every: int = 0\n    save_every: int = 10\n    log_path: str | None = None\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = \"ask\"\n\n    max_steps: int | None = None\n\n\nasync def cli_main(cli_config: CLIConfig, env: Any | None):\n    model_name_short = cli_config.model_name.replace(\"/\", \"-\")\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n    run_name = (\n        f\"verifiers_rl_{model_name_short}_gp{cli_config.groups_per_batch}_gs{cli_config.group_size}\"\n        f\"_lr{cli_config.learning_rate}_rank{cli_config.lora_rank}_{date_and_time}\"\n    )\n\n    log_path = cli_config.log_path or f\"/tmp/tinker-examples/verifiers_rl/{run_name}\"\n    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)\n\n    env_args = json.loads(cli_config.vf_env_args) if cli_config.vf_env_args else {}\n\n    shared_client: TinkerAsyncOpenAIClient | None = None\n    shared_renderer: renderers.Renderer | None = None\n    local_tokenizer: Tokenizer | None = None\n\n    async def custom_do_group_rollout(\n        builder: EnvGroupBuilder, policy: TokenCompleter\n    ) -> TrajectoryGroup:\n        nonlocal shared_client, shared_renderer, local_tokenizer\n\n        # initialize tokenizer and renderer lazily\n        if local_tokenizer is None:\n            local_tokenizer = get_tokenizer(cli_config.model_name)\n        if shared_renderer is None:\n            renderer_name = model_info.get_recommended_renderer_name(cli_config.model_name)\n            shared_renderer = renderers.get_renderer(renderer_name, local_tokenizer)\n\n        sampling_client = cast(TinkerTokenCompleter, policy).sampling_client\n        if shared_client is None:\n            shared_client = TinkerAsyncOpenAIClient(\n                sampling_client, shared_renderer, local_tokenizer\n            )\n        else:\n            shared_client.set_sampling_client(sampling_client)\n\n        
vf_builder = cast(VerifiersEnvGroupBuilder, builder)\n        rollout_inputs = vf_builder.get_rollout_inputs(cli_config.group_size)\n\n        gen_sem = await maybe_semaphore(cli_config.max_concurrent_generation)\n        score_sem = await maybe_semaphore(cli_config.max_concurrent_scoring)\n\n        states = await vf_builder.vf_env.run_group(\n            group_inputs=rollout_inputs,\n            client=shared_client,\n            model=\"tinker\",\n            gen_sampling_args={\n                \"max_tokens\": cli_config.max_tokens,\n                \"temperature\": cli_config.temperature,\n            },\n            gen_sem=gen_sem,\n            score_sem=score_sem,\n        )\n\n        return convert_states_to_trajectory_group(states)\n\n    # override do_group_rollout function inside rl.train\n    train.do_group_rollout = custom_do_group_rollout\n\n    dataset_builder = VerifiersRLDatasetBuilder(\n        vf_env_id=cli_config.vf_env_id,\n        vf_env_args=env_args,\n        groups_per_batch=cli_config.groups_per_batch,\n        dataset_n=cli_config.dataset_n,\n        dataset_seed=cli_config.dataset_seed,\n    )\n\n    cfg = train.Config(\n        learning_rate=cli_config.learning_rate,\n        dataset_builder=dataset_builder,\n        model_name=cli_config.model_name,\n        max_tokens=cli_config.max_tokens,\n        temperature=cli_config.temperature,\n        lora_rank=cli_config.lora_rank,\n        kl_penalty_coef=cli_config.kl_penalty_coef,\n        num_substeps=cli_config.num_substeps,\n        wandb_project=cli_config.wandb_project,\n        wandb_name=cli_config.wandb_name or run_name,\n        log_path=log_path,\n        eval_every=cli_config.eval_every,\n        save_every=cli_config.save_every,\n        stream_minibatch_config=None,\n        max_steps=cli_config.max_steps,\n    )\n\n    await train.main(cfg)\n\n\nif __name__ == \"__main__\":\n    cli_config = chz.entrypoint(CLIConfig)\n    asyncio.run(cli_main(cli_config, None))\n"
  },
  {
    "path": "tinker_cookbook/recipes/verifiers_rl/verifiers_env.py",
    "content": "from __future__ import annotations\n\nfrom collections.abc import Sequence\nfrom contextvars import ContextVar\n\nimport chz\nimport tinker\nimport verifiers as vf\n\nfrom tinker_cookbook.completers import TokensWithLogprobs\nfrom tinker_cookbook.rl.types import (\n    EnvGroupBuilder,\n    RLDataset,\n    RLDatasetBuilder,\n    Trajectory,\n    TrajectoryGroup,\n    Transition,\n)\n\n_vf_env_ctx: ContextVar[vf.Environment | None] = ContextVar(\"vf_env\", default=None)\n\n\ndef set_vf_env(env: vf.Environment) -> None:\n    \"\"\"Set the verifiers environment for the current context.\"\"\"\n    _vf_env_ctx.set(env)\n\n\ndef get_vf_env() -> vf.Environment | None:\n    \"\"\"Get the verifiers environment from the current context.\"\"\"\n    return _vf_env_ctx.get()\n\n\ndef convert_states_to_trajectory_group(states: list[vf.State]) -> TrajectoryGroup:\n    \"\"\"Convert verifiers States to tinker TrajectoryGroup.\"\"\"\n    trajectories_G: list[Trajectory] = []\n    final_rewards_G: list[float] = []\n    metrics_G: list[dict[str, float | int]] = []\n\n    for state in states:\n        transitions: list[Transition] = []\n        trajectory_steps = state.get(\"trajectory\", [])\n\n        for i, step in enumerate(trajectory_steps):\n            tokens_data = step.get(\"tokens\")\n            if tokens_data is not None:\n                prompt_ids = tokens_data.get(\"prompt_ids\", [])\n                ob = tinker.ModelInput.from_ints(prompt_ids)\n                completion_ids = tokens_data.get(\"completion_ids\", [])\n                completion_logprobs = tokens_data.get(\"completion_logprobs\", [])\n                ac = TokensWithLogprobs(\n                    tokens=completion_ids,\n                    maybe_logprobs=completion_logprobs,\n                )\n            else:\n                ob = tinker.ModelInput.empty()\n                ac = TokensWithLogprobs(tokens=[], maybe_logprobs=[])\n\n            is_last = i == len(trajectory_steps) - 1\n            transition = Transition(\n                ob=ob,\n                ac=ac,\n                reward=0.0,\n                episode_done=is_last,\n                metrics={},\n            )\n            transitions.append(transition)\n\n        trajectory = Trajectory(transitions=transitions, final_ob=tinker.ModelInput.empty())\n        trajectories_G.append(trajectory)\n        final_rewards_G.append(state.get(\"reward\") or 0.0)\n        metrics_G.append(state.get(\"metrics\") or {})\n\n    return TrajectoryGroup(\n        trajectories_G=trajectories_G,\n        final_rewards_G=final_rewards_G,\n        metrics_G=metrics_G,\n    )\n\n\nclass VerifiersRLDataset(RLDataset):\n    def __init__(\n        self,\n        rows: list[dict],\n        vf_env: vf.Environment,\n        groups_per_batch: int,\n    ):\n        self.rows = rows\n        self.vf_env = vf_env\n        self.groups_per_batch = groups_per_batch\n\n    def __len__(self) -> int:\n        return (len(self.rows) + self.groups_per_batch - 1) // self.groups_per_batch\n\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        start = index * self.groups_per_batch\n        end = min(len(self.rows), start + self.groups_per_batch)\n        builders: list[EnvGroupBuilder] = []\n        for j in range(start, end):\n            row = self.rows[j]\n            builders.append(\n                VerifiersEnvGroupBuilder(\n                    vf_env=self.vf_env,\n                    prompt=row[\"prompt\"],\n                    example_id=row[\"example_id\"],\n 
                   task=row[\"task\"],\n                    answer=row.get(\"answer\", \"\"),\n                    info=row.get(\"info\", {}),\n                )\n            )\n        return builders\n\n\n@chz.chz\nclass VerifiersRLDatasetBuilder(RLDatasetBuilder):\n    vf_env_id: str\n    vf_env_args: dict = chz.field(default_factory=dict)\n    groups_per_batch: int = 32\n    dataset_n: int = -1\n    dataset_seed: int | None = None\n\n    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:\n        vf_env = get_vf_env()\n        if vf_env is None:\n            vf_env = vf.load_environment(self.vf_env_id, **self.vf_env_args)\n            set_vf_env(vf_env)\n        ds = vf_env.get_dataset(n=self.dataset_n, seed=self.dataset_seed)\n        rows = [\n            {\n                \"prompt\": ds[\"prompt\"][i],\n                \"example_id\": ds[\"example_id\"][i],\n                \"task\": ds[\"task\"][i],\n                **({\"answer\": ds[\"answer\"][i]} if \"answer\" in ds.column_names else {}),\n                **({\"info\": ds[\"info\"][i]} if \"info\" in ds.column_names else {}),\n            }\n            for i in range(len(ds))\n        ]\n        return VerifiersRLDataset(rows, vf_env, self.groups_per_batch), None\n\n\nclass VerifiersEnvGroupBuilder(EnvGroupBuilder):\n    \"\"\"EnvGroupBuilder for the verifiers library integration.\n\n    Pickle support: ``vf.Environment`` is not pickleable. On deserialization,\n    it is recovered from the ``_vf_env_ctx`` context variable (set via\n    ``set_vf_env()``). Raises ``RuntimeError`` if the context variable is not\n    set — this is expected in cross-process scenarios since the verifiers\n    integration currently requires single-process execution (the\n    ``custom_do_group_rollout`` in train.py is a closure over shared state).\n    \"\"\"\n\n    def __init__(\n        self,\n        vf_env: vf.Environment,\n        prompt: vf.Messages,\n        example_id: int,\n        task: str,\n        answer: str = \"\",\n        info: dict | None = None,\n    ):\n        self.vf_env = vf_env\n        self.prompt = prompt\n        self.example_id = example_id\n        self.task = task\n        self.answer = answer\n        self.info = info or {}\n\n    def __getstate__(self) -> dict:\n        \"\"\"Exclude non-pickleable vf.Environment from pickle state.\"\"\"\n        state = self.__dict__.copy()\n        state[\"vf_env\"] = None\n        return state\n\n    def __setstate__(self, state: dict) -> None:\n        \"\"\"Restore vf.Environment from the context variable on unpickle.\"\"\"\n        vf_env = state.pop(\"vf_env\", None) or get_vf_env()\n        if vf_env is None:\n            raise RuntimeError(\n                \"VerifiersEnvGroupBuilder unpickled without a vf.Environment. \"\n                \"In cross-process scenarios (ProcessPoolExecutor, Ray), the worker \"\n                \"process must call set_vf_env(vf.load_environment(...)) before \"\n                \"unpickling builders. 
See verifiers_rl/train.py for reference.\"\n            )\n        self.vf_env = vf_env\n        self.prompt = state[\"prompt\"]\n        self.example_id = state[\"example_id\"]\n        self.task = state[\"task\"]\n        self.answer = state[\"answer\"]\n        self.info = state[\"info\"]\n\n    def get_rollout_inputs(self, group_size: int) -> list[vf.RolloutInput]:\n        return [\n            vf.RolloutInput(\n                prompt=self.prompt,\n                answer=self.answer,\n                task=self.task,\n                info=self.info,\n                example_id=self.example_id,\n            )\n            for _ in range(group_size)\n        ]\n\n    async def make_envs(self):\n        return []  # unused when using custom_do_group_rollout\n\n    def logging_tags(self) -> list[str]:\n        return [self.task] if self.task else []\n"
  },
  {
    "path": "tinker_cookbook/recipes/verifiers_rl/verifiers_pickle_test.py",
    "content": "\"\"\"Tests for picklability of VerifiersEnvGroupBuilder.\"\"\"\n\nimport pytest\n\ntry:\n    import verifiers as _verifiers  # noqa: F401\n\n    _has_verifiers = True\nexcept ImportError:\n    _has_verifiers = False\n\n\n@pytest.mark.skipif(not _has_verifiers, reason=\"verifiers not installed\")\nclass TestVerifiersEnvGroupBuilderPickle:\n    def test_pickle_excludes_vf_env(self) -> None:\n        \"\"\"VerifiersEnvGroupBuilder excludes vf_env from pickle state.\"\"\"\n        from unittest.mock import MagicMock\n\n        from tinker_cookbook.recipes.verifiers_rl.verifiers_env import VerifiersEnvGroupBuilder\n\n        builder = VerifiersEnvGroupBuilder(\n            vf_env=MagicMock(),\n            prompt=[{\"role\": \"user\", \"content\": \"What is 2+2?\"}],\n            example_id=42,\n            task=\"arithmetic\",\n            answer=\"4\",\n        )\n        state = builder.__getstate__()\n        assert state[\"vf_env\"] is None\n        assert state[\"prompt\"] == [{\"role\": \"user\", \"content\": \"What is 2+2?\"}]\n        assert state[\"example_id\"] == 42\n        assert state[\"task\"] == \"arithmetic\"\n        assert state[\"answer\"] == \"4\"\n"
  },
  {
    "path": "tinker_cookbook/recipes/vlm_classifier/README.md",
    "content": "# Supervised Learning\n\n## VLM Image Classification\n\nThis recipe will teach you how to train an image classifier powered by vision-language models on `tinker`.\n\n```bash\npython -m tinker_cookbook.recipes.vlm_classifier.train \\\n    experiment_dir=./vlm_classifier \\\n    wandb_project=vlm-classifier \\\n    dataset=caltech101 \\\n    renderer_name=qwen3_vl \\\n    model_name=Qwen/Qwen3-VL-30B-A3B-Instruct\n```\n\nCurrently, the qwen series of VLMs are supported. Running the above script after installing tinker-cookbook will fine-tune `Qwen/Qwen3-VL-30B-A3B-Instruct` on the `caltech101` as an example.\n\n### Evaluation\n\nOnce trained, you can evaluate the class predictions from your VLM as follows:\n\n```bash\npython -m tinker_cookbook.recipes.vlm_classifier.eval \\\n    dataset=caltech101 \\\n    model_path=$YOUR_MODEL_PATH \\\n    model_name=Qwen/Qwen3-VL-30B-A3B-Instruct \\\n    renderer_name=qwen3_vl\n```\n\nThis will print the test accuracy of your model.\n\n### Custom Datasets\n\nYou can add custom datasets by creating a custom `SupervisedDatasetBuilder` in `tinker_cookbook.recipes.vlm_classifier.data` if your dataset is available for download on Hugging Face, and has a column with your image, and a column with the image labels (note, you must also define a `ClassLabel` for mapping integer labels to a human-readable class name).\n\nFor more general datasets, you can subclass the base `ClassifierDataset` to load arbitrary image classification datasets in the provided classifier tooling.\n\n### Custom Evaluators\n\nWe provide a suite of evaluators in `tinker_cookbook.recipes.vlm_classifier.eval` for sampling from VLMs, parsing the predicted class name from the response, and computing evaluation metrics.\n\nTo define a custom evaluator for a new dataset, you can create a new `EvaluatorBuilder` if your dataset is available on Hugging Face, or you can subclass `ClassifierEvaluator` to add an arbitrary custom dataset and parsing strategy.\n"
  },
  {
    "path": "tinker_cookbook/recipes/vlm_classifier/data.py",
    "content": "\"\"\"\nDatasets for supervised learning (SFT) that use chat-formatted data, which we\nconvert to tokens using a Renderer.\n\"\"\"\n\nimport io\nimport logging\nimport math\nimport random\nfrom collections import defaultdict\nfrom typing import Any, cast\n\nimport chz\nimport datasets\nimport tinker\nimport torch\nfrom PIL import Image\n\nfrom tinker_cookbook.image_processing_utils import get_image_processor, resize_image\nfrom tinker_cookbook.renderers import (\n    ContentPart,\n    ImagePart,\n    Message,\n    TextPart,\n    TrainOnWhat,\n    get_renderer,\n)\nfrom tinker_cookbook.supervised.common import datum_from_model_input_weights\nfrom tinker_cookbook.supervised.types import SupervisedDataset, SupervisedDatasetBuilder\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass ClassifierDatasetConfig:\n    \"\"\"\n    Configuration for a classification dataset.\n    \"\"\"\n\n    dataset: str\n    dataset_split: str\n\n    image_column_name: str = \"image\"\n    label_column_name: str = \"label\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    num_repeats: float = 1\n    batch_size: int = 32\n    max_length: int = 8192\n\n    train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE\n\n    # If set, sample only this many examples per class (for few-shot experiments)\n    examples_per_class: int | None = None\n    subset_seed: int = 0\n\n    max_image_size: int = 480\n    hflip_probability: float = 0.5\n\n\nclass ClassifierDataset(SupervisedDataset):\n    def __init__(self, config: ClassifierDatasetConfig):\n        \"\"\"\n        Construct a VLM classifier dataset with the provided data config.\n        \"\"\"\n\n        self.config = config\n\n        tokenizer = get_tokenizer(self.config.model_name_for_tokenizer)\n        image_processor = get_image_processor(self.config.model_name_for_tokenizer)\n\n        self.renderer = get_renderer(\n            name=self.config.renderer_name, tokenizer=tokenizer, image_processor=image_processor\n        )\n\n        dataset = datasets.load_dataset(self.config.dataset)\n        dataset = cast(datasets.DatasetDict, dataset)\n        self.dataset = dataset[self.config.dataset_split]\n\n        # If examples_per_class is set, sample N examples per class for few-shot setting\n        if self.config.examples_per_class is not None:\n            self.dataset = self._sample_per_class(self.dataset)\n\n        self.class_labels = self.dataset.features[self.config.label_column_name]\n        self.shuffled_indices = self.get_shuffled_indices()\n\n    def get_shuffled_indices(self, seed: int = 0) -> list[int]:\n        \"\"\"\n        Get a shuffled set of dataset indices with a target number of num_repeats.\n        \"\"\"\n\n        max_repeat = int(math.ceil(self.config.num_repeats))\n        max_examples = int(math.ceil(self.config.num_repeats * len(self.dataset)))\n\n        random_gen = random.Random(seed)\n        shuffled_indices: list[int] = []\n\n        for _ in range(max_repeat):\n            dataset_indices = list(range(len(self.dataset)))\n            random_gen.shuffle(dataset_indices)\n            shuffled_indices.extend(dataset_indices)\n\n        return shuffled_indices[:max_examples]\n\n    def _sample_per_class(self, dataset: datasets.Dataset) -> datasets.Dataset:\n        \"\"\"\n        Sample up to N examples per class from the dataset for few-shot experiments.\n        Uses self.config.examples_per_class, label_column_name, and 
subset_seed.\n        \"\"\"\n        rng = random.Random(self.config.subset_seed)\n\n        # Group indices by class label\n        class_indices: dict[int, list[int]] = defaultdict(list)\n        for idx, label in enumerate(dataset[self.config.label_column_name]):\n            class_indices[label].append(idx)\n\n        # Shuffle and sample up to examples_per_class from each class\n        selected_indices: list[int] = []\n        for label in sorted(class_indices.keys()):\n            indices = class_indices[label]\n            rng.shuffle(indices)\n\n            selected_indices.extend(indices[: self.config.examples_per_class])\n\n        logger.info(\n            f\"Sampled {len(selected_indices)} examples \"\n            f\"({self.config.examples_per_class} per class, {len(class_indices)} classes)\"\n        )\n\n        return dataset.select(selected_indices)\n\n    def get_class_name(self, label: str) -> str:\n        \"\"\"\n        Helper function to clean up the original class name.\n        \"\"\"\n\n        return label.replace(\"_\", \" \").replace(\".\", \" \").replace(\"-\", \" \").lower()\n\n    def build_supervised_example(\n        self,\n        example: dict[str, Any],\n    ) -> tuple[tinker.ModelInput, torch.Tensor]:\n        \"\"\"\n        Generate an input to prompt the model.\n        \"\"\"\n\n        class_label = example[self.config.label_column_name]\n        class_label_name = self.get_class_name(self.class_labels.int2str(class_label))\n\n        image = example[self.config.image_column_name]\n        pil_image: Image.Image | None = None\n\n        if isinstance(image, dict) and \"bytes\" in image:\n            pil_image = Image.open(io.BytesIO(image[\"bytes\"]))\n\n        elif isinstance(image, Image.Image):\n            pil_image = cast(Image.Image, image)\n\n        # If the dataset cannot be loaded\n        if pil_image is None:\n            raise AssertionError(f\"Unable to interpret {image} as an image\")\n\n        pil_image = resize_image(image=pil_image, max_size=self.config.max_image_size)\n\n        # horizontal flip 50% of the time\n        if random.random() < self.config.hflip_probability:\n            pil_image = pil_image.transpose(Image.Transpose.FLIP_LEFT_RIGHT)\n\n        user_parts: list[ContentPart] = [\n            ImagePart(type=\"image\", image=pil_image),\n            TextPart(type=\"text\", text=\"What is the name of the subject in this photo?\"),\n        ]\n\n        assistant_parts: list[ContentPart] = [\n            TextPart(type=\"text\", text=f\"The subject in this photo is: {class_label_name}\\n\"),\n        ]\n\n        messages = [\n            Message(role=\"user\", content=user_parts),\n            Message(role=\"assistant\", content=assistant_parts),\n        ]\n\n        return self.renderer.build_supervised_example(\n            messages=messages,\n            train_on_what=self.config.train_on_what,\n        )\n\n    def get_batch(self, index: int) -> list[tinker.Datum]:\n        \"\"\"\n        Load a batch of training examples.\n        \"\"\"\n\n        return [\n            datum_from_model_input_weights(\n                *self.build_supervised_example(self.dataset[self.shuffled_indices[idx]]),\n                max_length=self.config.max_length,\n            )\n            for idx in range(\n                self.config.batch_size * index,\n                min(self.config.batch_size * (index + 1), len(self.shuffled_indices)),\n            )\n        ]\n\n    def __len__(self) -> int:\n        \"\"\"\n        
Number of batches in the dataloader\n        \"\"\"\n\n        return int(math.ceil(len(self.shuffled_indices) / self.config.batch_size))\n\n    def set_epoch(self, seed: int = 0):\n        \"\"\"\n        Set the epoch for shuffling the dataloader.\n        \"\"\"\n\n        self.shuffled_indices = self.get_shuffled_indices(seed=seed)\n\n\n@chz.chz\nclass Caltech101DatasetBuilder(SupervisedDatasetBuilder):\n    \"\"\"\n    Configuration for a classification dataset.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    num_repeats: float = 1\n    batch_size: int = 32\n    max_length: int = 8192\n\n    train_on_what: TrainOnWhat | None = None\n\n    # If set, sample only this many examples per class (for few-shot experiments)\n    examples_per_class: int | None = None\n    subset_seed: int = 0\n\n    max_image_size: int = 480\n\n    run_nll_evaluator: bool = False\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        default_train_on_what = self.train_on_what or TrainOnWhat.LAST_ASSISTANT_MESSAGE\n\n        train_config = ClassifierDatasetConfig(\n            dataset=\"dpdl-benchmark/caltech101\",\n            dataset_split=\"train\",\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            num_repeats=self.num_repeats,\n            batch_size=self.batch_size,\n            max_length=self.max_length,\n            train_on_what=default_train_on_what,\n            examples_per_class=self.examples_per_class,\n            subset_seed=self.subset_seed,\n            max_image_size=self.max_image_size,\n            hflip_probability=0.5,\n        )\n\n        test_config = ClassifierDatasetConfig(\n            dataset=\"dpdl-benchmark/caltech101\",\n            dataset_split=\"test\",\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            batch_size=self.batch_size,\n            max_length=self.max_length,\n            train_on_what=default_train_on_what,\n            max_image_size=self.max_image_size,\n            hflip_probability=0.0,  # No augmentation for test set\n            # Note: test set uses full data, no few-shot sampling\n        )\n\n        train_dataset = ClassifierDataset(train_config)\n\n        if not self.run_nll_evaluator:\n            return train_dataset, None\n\n        return train_dataset, ClassifierDataset(test_config)\n\n\n@chz.chz\nclass Flowers102DatasetBuilder(SupervisedDatasetBuilder):\n    \"\"\"\n    Configuration for a classification dataset.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    num_repeats: float = 1\n    batch_size: int = 32\n    max_length: int = 8192\n\n    train_on_what: TrainOnWhat | None = None\n\n    # If set, sample only this many examples per class (for few-shot experiments)\n    examples_per_class: int | None = None\n    subset_seed: int = 0\n\n    max_image_size: int = 480\n\n    run_nll_evaluator: bool = False\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        default_train_on_what = self.train_on_what or TrainOnWhat.LAST_ASSISTANT_MESSAGE\n\n        train_config = ClassifierDatasetConfig(\n            dataset=\"dpdl-benchmark/oxford_flowers102\",\n            dataset_split=\"train\",\n            
image_column_name=\"image\",\n            label_column_name=\"label\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            num_repeats=self.num_repeats,\n            batch_size=self.batch_size,\n            max_length=self.max_length,\n            train_on_what=default_train_on_what,\n            examples_per_class=self.examples_per_class,\n            subset_seed=self.subset_seed,\n            max_image_size=self.max_image_size,\n            hflip_probability=0.5,\n        )\n\n        test_config = ClassifierDatasetConfig(\n            dataset=\"dpdl-benchmark/oxford_flowers102\",\n            dataset_split=\"test\",\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            batch_size=self.batch_size,\n            max_length=self.max_length,\n            train_on_what=default_train_on_what,\n            max_image_size=self.max_image_size,\n            hflip_probability=0.0,\n            # Note: test set uses full data, no few-shot sampling\n        )\n\n        train_dataset = ClassifierDataset(train_config)\n\n        if not self.run_nll_evaluator:\n            return train_dataset, None\n\n        return train_dataset, ClassifierDataset(test_config)\n\n\n@chz.chz\nclass OxfordPetsDatasetBuilder(SupervisedDatasetBuilder):\n    \"\"\"\n    Configuration for a classification dataset.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    num_repeats: float = 1\n    batch_size: int = 32\n    max_length: int = 8192\n\n    train_on_what: TrainOnWhat | None = None\n\n    # If set, sample only this many examples per class (for few-shot experiments)\n    examples_per_class: int | None = None\n    subset_seed: int = 0\n\n    max_image_size: int = 480\n\n    run_nll_evaluator: bool = False\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        default_train_on_what = self.train_on_what or TrainOnWhat.LAST_ASSISTANT_MESSAGE\n\n        train_config = ClassifierDatasetConfig(\n            dataset=\"dpdl-benchmark/oxford_iiit_pet\",\n            dataset_split=\"train\",\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            num_repeats=self.num_repeats,\n            batch_size=self.batch_size,\n            max_length=self.max_length,\n            train_on_what=default_train_on_what,\n            examples_per_class=self.examples_per_class,\n            subset_seed=self.subset_seed,\n            max_image_size=self.max_image_size,\n            hflip_probability=0.5,\n        )\n\n        test_config = ClassifierDatasetConfig(\n            dataset=\"dpdl-benchmark/oxford_iiit_pet\",\n            dataset_split=\"test\",\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            batch_size=self.batch_size,\n            max_length=self.max_length,\n            train_on_what=default_train_on_what,\n            max_image_size=self.max_image_size,\n            hflip_probability=0.0,\n            # Note: test set uses full data, no few-shot sampling\n        )\n\n        train_dataset = 
ClassifierDataset(train_config)\n\n        if not self.run_nll_evaluator:\n            return train_dataset, None\n\n        return train_dataset, ClassifierDataset(test_config)\n\n\n@chz.chz\nclass StanfordCarsDatasetBuilder(SupervisedDatasetBuilder):\n    \"\"\"\n    Configuration for a classification dataset.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    num_repeats: float = 1\n    batch_size: int = 32\n    max_length: int = 8192\n\n    train_on_what: TrainOnWhat | None = None\n\n    # If set, sample only this many examples per class (for few-shot experiments)\n    examples_per_class: int | None = None\n    subset_seed: int = 0\n\n    max_image_size: int = 480\n\n    run_nll_evaluator: bool = False\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        default_train_on_what = self.train_on_what or TrainOnWhat.LAST_ASSISTANT_MESSAGE\n\n        train_config = ClassifierDatasetConfig(\n            dataset=\"tanganke/stanford_cars\",\n            dataset_split=\"train\",\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            num_repeats=self.num_repeats,\n            batch_size=self.batch_size,\n            max_length=self.max_length,\n            train_on_what=default_train_on_what,\n            examples_per_class=self.examples_per_class,\n            subset_seed=self.subset_seed,\n            max_image_size=self.max_image_size,\n            hflip_probability=0.5,\n        )\n\n        test_config = ClassifierDatasetConfig(\n            dataset=\"tanganke/stanford_cars\",\n            dataset_split=\"test\",\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            batch_size=self.batch_size,\n            max_length=self.max_length,\n            train_on_what=default_train_on_what,\n            max_image_size=self.max_image_size,\n            hflip_probability=0.0,\n            # Note: test set uses full data, no few-shot sampling\n        )\n\n        train_dataset = ClassifierDataset(train_config)\n\n        if not self.run_nll_evaluator:\n            return train_dataset, None\n\n        return train_dataset, ClassifierDataset(test_config)\n\n\nDATASETS = {\n    \"caltech101\": Caltech101DatasetBuilder,\n    \"flowers102\": Flowers102DatasetBuilder,\n    \"pets\": OxfordPetsDatasetBuilder,\n    \"cars\": StanfordCarsDatasetBuilder,\n}\n\n\ndef get_dataset_builder(\n    dataset: str,\n    model_name_for_tokenizer: str,\n    renderer_name: str,\n    num_repeats: float = 1,\n    batch_size: int = 32,\n    max_length: int = 8192,\n    train_on_what: TrainOnWhat | None = None,\n    examples_per_class: int | None = None,\n    subset_seed: int = 0,\n    max_image_size: int = 480,\n    run_nll_evaluator: bool = False,\n) -> SupervisedDatasetBuilder:\n    \"\"\"\n    Create a training and test dataset for a vlm classifier.\n\n    Args:\n        examples_per_class: If set, sample only this many examples per class\n            from the training set (for few-shot experiments). Test set is\n            unaffected.\n        subset_seed: Seed for shuffling before selecting the few-shot subset.\n        max_image_size: Maximum size for the longest side of images. 
Images\n            larger than this will be resized while preserving aspect ratio.\n    \"\"\"\n\n    return DATASETS[dataset](\n        model_name_for_tokenizer=model_name_for_tokenizer,\n        renderer_name=renderer_name,\n        num_repeats=num_repeats,\n        batch_size=batch_size,\n        max_length=max_length,\n        train_on_what=train_on_what,\n        examples_per_class=examples_per_class,\n        subset_seed=subset_seed,\n        max_image_size=max_image_size,\n        run_nll_evaluator=run_nll_evaluator,\n    )\n"
  },
  {
    "path": "tinker_cookbook/recipes/vlm_classifier/eval.py",
    "content": "import asyncio\nimport io\nimport logging\nfrom typing import Any, TypedDict, cast\n\nimport chz\nimport datasets\nimport numpy as np\nimport tinker\nfrom PIL import Image\nfrom tinker import types\n\nfrom tinker_cookbook import checkpoint_utils, model_info, renderers\nfrom tinker_cookbook.eval.evaluators import EvaluatorBuilder, SamplingClientEvaluator\nfrom tinker_cookbook.image_processing_utils import get_image_processor, resize_image\nfrom tinker_cookbook.renderers import ImagePart, Message, TextPart, get_text_content\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.utils.misc_utils import timed\n\n# Set up logger\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass ClassifierEvaluatorConfig:\n    \"\"\"\n    Configuration for classifier evaluation.\n    \"\"\"\n\n    dataset: str\n    dataset_split: str\n\n    image_column_name: str = \"image\"\n    label_column_name: str = \"label\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int | None = None\n    max_parallel_tasks: int = 128\n\n    max_image_size: int = 480\n\n\nclass ClassifierOutput(TypedDict):\n    \"\"\"\n    Parsed output from an image classification model.\n    \"\"\"\n\n    predicted_class_name: str\n\n\nclass ClassifierEvaluator(SamplingClientEvaluator):\n    \"\"\"\n    Evaluator that runs image classification evaluation.\n    \"\"\"\n\n    def __init__(\n        self,\n        config: ClassifierEvaluatorConfig,\n    ):\n        \"\"\"\n        Initialize the CustomEvaluator.\n        Args:\n            config: Configuration object containing all evaluation parameters\n        \"\"\"\n\n        self.config = config\n\n        tokenizer = get_tokenizer(self.config.model_name_for_tokenizer)\n        image_processor = get_image_processor(self.config.model_name_for_tokenizer)\n\n        self.renderer = renderers.get_renderer(\n            name=self.config.renderer_name, tokenizer=tokenizer, image_processor=image_processor\n        )\n\n        dataset = datasets.load_dataset(self.config.dataset)\n        dataset = cast(datasets.DatasetDict, dataset)\n        self.dataset = dataset[self.config.dataset_split]\n\n        self.shuffled_dataset = self.dataset.shuffle(seed=0)\n        self.class_labels = self.dataset.features[self.config.label_column_name]\n\n    def get_class_name(self, label: str) -> str:\n        \"\"\"\n        Helper function to clean up the original class name.\n        \"\"\"\n\n        return label.replace(\"_\", \" \").replace(\".\", \" \").replace(\"-\", \" \").lower()\n\n    def build_generation_prompt(\n        self,\n        example: dict[str, Any],\n    ) -> tinker.ModelInput:\n        \"\"\"\n        Generate an input to prompt the model.\n        \"\"\"\n\n        image = example[self.config.image_column_name]\n        pil_image: Image.Image | None = None\n\n        if isinstance(image, dict) and \"bytes\" in image:\n            pil_image = Image.open(io.BytesIO(image[\"bytes\"]))\n\n        elif isinstance(image, Image.Image):\n            pil_image = cast(Image.Image, image)\n\n        # If the dataset cannot be loaded\n        if pil_image is None:\n            raise AssertionError(f\"Unable to interpret {image} as an image\")\n\n        pil_image = resize_image(image=pil_image, max_size=self.config.max_image_size)\n\n        content_parts = [\n            ImagePart(type=\"image\", image=pil_image),\n          
  TextPart(type=\"text\", text=\"What is the name of the subject in this photo?\"),\n        ]\n\n        messages = [\n            Message(role=\"user\", content=content_parts),\n        ]\n\n        return self.renderer.build_generation_prompt(\n            messages=messages, role=\"assistant\", prefill=\"The subject in this photo is:\"\n        )\n\n    async def generate_output(\n        self,\n        model_input: tinker.ModelInput,\n        sampling_client: tinker.SamplingClient,\n        sampling_params: types.SamplingParams,\n    ) -> ClassifierOutput:\n        \"\"\"\n        Generate a completion and extract the class name from the model.\n        \"\"\"\n\n        # Generate response\n        r: types.SampleResponse = await sampling_client.sample_async(\n            prompt=model_input, num_samples=1, sampling_params=sampling_params\n        )\n        tokens: list[int] = r.sequences[0].tokens\n        response = self.renderer.parse_response(tokens)[0]\n\n        predicted_class_name = get_text_content(response).split(\":\")[-1].strip().lower()\n\n        return ClassifierOutput(predicted_class_name=predicted_class_name)\n\n    def get_metrics_for_output(\n        self, example: dict[str, Any], classifier_output: ClassifierOutput\n    ) -> dict[str, float]:\n        \"\"\"\n        Score the class name predicted by the model.\n        \"\"\"\n\n        predicted_class_name = classifier_output[\"predicted_class_name\"]\n        class_label = example[self.config.label_column_name]\n        class_label_name = self.get_class_name(self.class_labels.int2str(class_label))\n\n        return {\"accuracy\": float(predicted_class_name == class_label_name)}\n\n    async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:\n        \"\"\"\n        Evaluate a vision-language model as an image classifier.\n\n        Args:\n            sampling_client: The sampling client to evaluate\n\n        Returns:\n            Dictionary of metrics from evaluation\n\n        \"\"\"\n\n        sampling_params = types.SamplingParams(\n            max_tokens=self.config.max_tokens,\n            temperature=self.config.temperature,\n            top_p=self.config.top_p,\n            top_k=self.config.top_k,\n            stop=self.renderer.get_stop_sequences(),\n        )\n\n        num_examples = min(\n            len(self.shuffled_dataset), self.config.n_eval or len(self.shuffled_dataset)\n        )\n\n        # Limit concurrent sampling tasks\n        semaphore = asyncio.Semaphore(self.config.max_parallel_tasks)\n\n        async def bounded_generate_output(example: dict[str, Any]) -> ClassifierOutput:\n            async with semaphore:\n                return await self.generate_output(\n                    self.build_generation_prompt(example), sampling_client, sampling_params\n                )\n\n        # Sample from the model in parallel\n        async_tasks = []\n\n        logger.info(\n            f\"Submitting {num_examples} sampling tasks (max {self.config.max_parallel_tasks} parallel)\"\n        )\n        for example_id in range(num_examples):\n            example = self.shuffled_dataset[example_id]\n\n            # Prepare model input for sampling, generate\n            async_tasks.append(asyncio.create_task(bounded_generate_output(example)))\n\n        # Wait for the tinker API to return the sampled completions\n        with timed(\"sample outputs\", {}):\n            outputs = await asyncio.gather(*async_tasks)\n\n        # Aggregate metrics for each example\n        
metrics_per_example = []\n\n        logger.info(f\"Evaluating {num_examples} sampled responses\")\n        for example_id in range(num_examples):\n            example = self.shuffled_dataset[example_id]\n            output = outputs[example_id]\n\n            # Evaluate the model response\n            metrics = self.get_metrics_for_output(example, output)\n            metrics_per_example.append(metrics)\n\n        # aggregate the performance metrics\n        aggregated_metrics = {\n            key: np.mean([example[key] for example in metrics_per_example]).item()\n            for key in metrics_per_example[0]\n        }\n\n        return aggregated_metrics\n\n\n@chz.chz\nclass Caltech101EvaluatorBuilder:\n    \"\"\"\n    Configuration for classifier evaluation.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int | None = None\n    max_parallel_tasks: int = 128\n\n    max_image_size: int = 480\n\n    def __call__(self) -> ClassifierEvaluator:\n        config = ClassifierEvaluatorConfig(\n            dataset=\"dpdl-benchmark/caltech101\",\n            dataset_split=\"test\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            temperature=self.temperature,\n            max_tokens=self.max_tokens,\n            top_p=self.top_p,\n            top_k=self.top_k,\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            n_eval=self.n_eval,\n            max_parallel_tasks=self.max_parallel_tasks,\n            max_image_size=self.max_image_size,\n        )\n\n        return ClassifierEvaluator(config)\n\n\n@chz.chz\nclass Flowers102EvaluatorBuilder:\n    \"\"\"\n    Configuration for classifier evaluation.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int | None = None\n    max_parallel_tasks: int = 128\n\n    max_image_size: int = 480\n\n    def __call__(self) -> ClassifierEvaluator:\n        config = ClassifierEvaluatorConfig(\n            dataset=\"dpdl-benchmark/oxford_flowers102\",\n            dataset_split=\"test\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            temperature=self.temperature,\n            max_tokens=self.max_tokens,\n            top_p=self.top_p,\n            top_k=self.top_k,\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            n_eval=self.n_eval,\n            max_parallel_tasks=self.max_parallel_tasks,\n            max_image_size=self.max_image_size,\n        )\n\n        return ClassifierEvaluator(config)\n\n\n@chz.chz\nclass OxfordPetsEvaluatorBuilder:\n    \"\"\"\n    Configuration for classifier evaluation.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int | None = None\n    max_parallel_tasks: int = 128\n\n    max_image_size: int = 480\n\n    def __call__(self) -> ClassifierEvaluator:\n        config = ClassifierEvaluatorConfig(\n            dataset=\"dpdl-benchmark/oxford_iiit_pet\",\n            dataset_split=\"test\",\n            renderer_name=self.renderer_name,\n            
model_name_for_tokenizer=self.model_name_for_tokenizer,\n            temperature=self.temperature,\n            max_tokens=self.max_tokens,\n            top_p=self.top_p,\n            top_k=self.top_k,\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            n_eval=self.n_eval,\n            max_parallel_tasks=self.max_parallel_tasks,\n            max_image_size=self.max_image_size,\n        )\n\n        return ClassifierEvaluator(config)\n\n\n@chz.chz\nclass StanfordCarsEvaluatorBuilder:\n    \"\"\"\n    Configuration for classifier evaluation.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n    renderer_name: str\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int | None = None\n    max_parallel_tasks: int = 128\n\n    max_image_size: int = 480\n\n    def __call__(self) -> ClassifierEvaluator:\n        config = ClassifierEvaluatorConfig(\n            dataset=\"tanganke/stanford_cars\",\n            dataset_split=\"test\",\n            renderer_name=self.renderer_name,\n            model_name_for_tokenizer=self.model_name_for_tokenizer,\n            temperature=self.temperature,\n            max_tokens=self.max_tokens,\n            top_p=self.top_p,\n            top_k=self.top_k,\n            image_column_name=\"image\",\n            label_column_name=\"label\",\n            n_eval=self.n_eval,\n            max_parallel_tasks=self.max_parallel_tasks,\n            max_image_size=self.max_image_size,\n        )\n\n        return ClassifierEvaluator(config)\n\n\nEVALUATORS = {\n    \"caltech101\": Caltech101EvaluatorBuilder,\n    \"flowers102\": Flowers102EvaluatorBuilder,\n    \"pets\": OxfordPetsEvaluatorBuilder,\n    \"cars\": StanfordCarsEvaluatorBuilder,\n}\n\n\ndef get_evaluator_builder(\n    dataset: str,\n    model_name_for_tokenizer: str,\n    renderer_name: str,\n    temperature: float = 0.0,\n    max_tokens: int = 128,\n    top_p: float = 1.0,\n    top_k: int = -1,\n    n_eval: int | None = None,\n    max_parallel_tasks: int = 128,\n    max_image_size: int = 480,\n) -> EvaluatorBuilder:\n    \"\"\"\n    Create a sampling based evaluator for a vlm classifier.\n    \"\"\"\n\n    return EVALUATORS[dataset](\n        model_name_for_tokenizer=model_name_for_tokenizer,\n        renderer_name=renderer_name,\n        temperature=temperature,\n        max_tokens=max_tokens,\n        top_p=top_p,\n        top_k=top_k,\n        n_eval=n_eval,\n        max_parallel_tasks=max_parallel_tasks,\n        max_image_size=max_image_size,\n    )\n\n\n@chz.chz\nclass EvalConfig:\n    \"\"\"\n    Config for launching evaluation on a model checkpoint.\n    \"\"\"\n\n    dataset: str\n    model_path: str\n\n    renderer_name: str | None = None\n    model_name: str | None = None\n\n    # Infrastructure parameters\n    base_url: str | None = None\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int | None = None\n    max_parallel_tasks: int = 128\n\n    max_image_size: int = 480\n\n\ndef run_eval(eval_config: EvalConfig):\n    \"\"\"\n    Launch evaluation on a model checkpoint on an image dataset.\n    \"\"\"\n\n    service_client = tinker.ServiceClient(base_url=eval_config.base_url)\n    sampling_client = service_client.create_sampling_client(model_path=eval_config.model_path)\n\n    rest_client = service_client.create_rest_client()\n    training_run = 
rest_client.get_training_run_by_tinker_path(eval_config.model_path).result()\n    if eval_config.model_name is not None and eval_config.model_name != training_run.base_model:\n        raise ValueError(\n            f\"Model name {eval_config.model_name} does not match training run base model {training_run.base_model}\"\n        )\n    model_name = eval_config.model_name or training_run.base_model\n    renderer_name = eval_config.renderer_name or checkpoint_utils.get_renderer_name_from_checkpoint(\n        service_client, eval_config.model_path\n    )\n    if renderer_name is None:\n        renderer_name = model_info.get_recommended_renderer_name(model_name)\n    logger.info(f\"Using model: {model_name}\")\n    logger.info(f\"Using renderer: {renderer_name}\")\n\n    evaluator_builder = get_evaluator_builder(\n        dataset=eval_config.dataset,\n        model_name_for_tokenizer=model_name,\n        renderer_name=renderer_name,\n        temperature=eval_config.temperature,\n        max_tokens=eval_config.max_tokens,\n        top_p=eval_config.top_p,\n        top_k=eval_config.top_k,\n        n_eval=eval_config.n_eval,\n        max_parallel_tasks=eval_config.max_parallel_tasks,\n        max_image_size=eval_config.max_image_size,\n    )\n\n    evaluator = evaluator_builder()\n\n    async def main():\n        result = await evaluator(sampling_client)  # type: ignore[arg-type]\n        print(f\"Metrics = {result}\")\n\n    asyncio.run(main())\n\n\nif __name__ == \"__main__\":\n    chz.nested_entrypoint(run_eval)\n"
  },
  {
    "path": "tinker_cookbook/recipes/vlm_classifier/eval_sweep.py",
    "content": "\"\"\"\n\n## VLM Image Classifier\n\nLauncher for evaluating trained image classifiers.\n\n```bash\npython -m tinker_cookbook.recipes.vlm_classifier.eval_sweep \\\n    experiment_dir=$HOME/tinker-experiments output_file=results.json\n```\n\nWith early stopping (use best checkpoint per run based on validation accuracy):\n\n```bash\npython -m tinker_cookbook.recipes.vlm_classifier.eval_sweep \\\n    experiment_dir=$HOME/tinker-experiments output_file=results.json\n```\n\n\"\"\"\n\nimport asyncio\nimport json\nimport logging\nimport re\nfrom pathlib import Path\nfrom typing import Any\n\nimport chz\nimport tinker\n\nfrom tinker_cookbook.checkpoint_utils import (\n    CheckpointRecord,\n    get_last_checkpoint,\n    load_checkpoints_file,\n)\nfrom tinker_cookbook.recipes.vlm_classifier.eval import get_evaluator_builder\n\n# Set up logger\nlogger = logging.getLogger(__name__)\n\n\ndef get_checkpoint_at_step(\n    log_dir: str,\n    step: int,\n    required_key: str = \"sampler_path\",\n) -> CheckpointRecord | None:\n    \"\"\"\n    Get the checkpoint at a specific step from the checkpoints.jsonl file.\n\n    Args:\n        log_dir: The directory containing checkpoints.jsonl.\n        step: The step number to find.\n        required_key: The key to check for in the checkpoint.\n\n    Returns:\n        The checkpoint at the specified step, or None if not found.\n    \"\"\"\n    checkpoints = load_checkpoints_file(log_dir)\n    for checkpoint in checkpoints:\n        if checkpoint.batch == step and checkpoint.has(required_key):\n            logger.info(f\"Found checkpoint at step {step}: {checkpoint}\")\n            return checkpoint\n    logger.warning(f\"No checkpoint found at step {step} with key '{required_key}' in {log_dir}\")\n    return None\n\n\ndef parse_hyperparams_from_experiment_name(experiment_name: str) -> dict[str, Any]:\n    \"\"\"\n    Parse hyperparameters from the experiment directory name.\n\n    Experiment names follow the format from sweep.py:\n    {dataset}-{model_name}-{lora_rank}rank-{learning_rate}lr-{batch_size}batch-{examples_per_class}shot-seed{subset_seed}-{date}\n\n    Example: caltech101-Qwen-Qwen3-VL-235B-A22B-Instruct-32rank-0.0005lr-32batch-4shot-seed0-2025-11-26\n    \"\"\"\n\n    hyperparams: dict[str, Any] = {}\n\n    # Parse dataset: first segment before the first dash\n    hyperparams[\"dataset\"] = experiment_name.split(\"-\")[0]\n\n    # Parse lora_rank: look for pattern like \"32rank\"\n    if match := re.search(r\"-(\\d+)rank-\", experiment_name):\n        hyperparams[\"lora_rank\"] = int(match.group(1))\n\n    # Parse learning_rate: look for pattern like \"0.0005lr\" or \"5e-4lr\"\n    if match := re.search(r\"-([\\d.e+-]+)lr-\", experiment_name):\n        hyperparams[\"learning_rate\"] = float(match.group(1))\n\n    # Parse batch_size: look for pattern like \"32batch\"\n    if match := re.search(r\"-(\\d+)batch\", experiment_name):\n        hyperparams[\"batch_size\"] = int(match.group(1))\n\n    # Parse examples_per_class (shot): look for pattern like \"4shot\"\n    if match := re.search(r\"-(\\d+)shot-\", experiment_name):\n        hyperparams[\"examples_per_class\"] = int(match.group(1))\n\n    # Parse subset_seed: look for pattern like \"seed0\"\n    if match := re.search(r\"-seed(\\d+)-\", experiment_name):\n        hyperparams[\"subset_seed\"] = int(match.group(1))\n\n    # Parse date: look for pattern like \"2025-11-26\" at the end\n    if match := re.search(r\"-(\\d{4}-\\d{2}-\\d{2})$\", experiment_name):\n        
hyperparams[\"date\"] = match.group(1)\n\n    return hyperparams\n\n\n@chz.chz\nclass EvalConfig:\n    \"\"\"\n    Config for evaluating all experiments in a sweep directory.\n    \"\"\"\n\n    experiment_dir: str\n    output_file: str\n\n    renderer_name: str = \"qwen3_vl\"\n    model_name: str = \"Qwen/Qwen3-VL-235B-A22B-Instruct\"\n\n    # Infrastructure parameters\n    base_url: str | None = None\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int | None = None\n    max_parallel_tasks: int = 1024\n    max_parallel_evals: int = 5\n\n    max_image_size: int = 480\n\n    # Early stopping: map experiment name to the step of the best checkpoint\n    # If not provided or experiment not in dict, uses the last checkpoint\n    early_stopping_checkpoints: dict[str, int] | None = None\n\n\nasync def evaluate_experiment(\n    experiment_name: str,\n    eval_config: EvalConfig,\n    service_client: tinker.ServiceClient,\n) -> dict[str, Any]:\n    \"\"\"\n    Evaluate a single few-shot image classifier experiment.\n    \"\"\"\n\n    experiment_path = Path(eval_config.experiment_dir) / experiment_name\n    assert experiment_path.is_dir(), f\"Experiment directory does not exist: {experiment_path}\"\n\n    # Load checkpoint: use early stopping step if provided, otherwise use last checkpoint\n    early_stop_step = (\n        eval_config.early_stopping_checkpoints.get(experiment_name)\n        if eval_config.early_stopping_checkpoints\n        else None\n    )\n\n    experiment_path_str = str(experiment_path)\n    if early_stop_step is not None:\n        checkpoint = get_checkpoint_at_step(\n            experiment_path_str, early_stop_step, required_key=\"sampler_path\"\n        )\n        assert checkpoint is not None, (\n            f\"No checkpoint at step {early_stop_step} with sampler_path found in {experiment_path}\"\n        )\n        logger.info(\n            f\"Using early stopping checkpoint at step {early_stop_step} for {experiment_name}\"\n        )\n    else:\n        checkpoint = get_last_checkpoint(experiment_path_str, required_key=\"sampler_path\")\n        assert checkpoint is not None, f\"No checkpoint with sampler_path found in {experiment_path}\"\n\n    # Parse hyperparameters (including dataset) from directory name\n    hyperparams = parse_hyperparams_from_experiment_name(experiment_name)\n    assert \"dataset\" in hyperparams, f\"Unable to parse the dataset name from {experiment_path}\"\n\n    # Create evaluator for this dataset\n    evaluator_builder = get_evaluator_builder(\n        dataset=hyperparams[\"dataset\"],\n        model_name_for_tokenizer=eval_config.model_name,\n        renderer_name=eval_config.renderer_name,\n        temperature=eval_config.temperature,\n        max_tokens=eval_config.max_tokens,\n        top_p=eval_config.top_p,\n        top_k=eval_config.top_k,\n        n_eval=eval_config.n_eval,\n        max_parallel_tasks=eval_config.max_parallel_tasks,\n        max_image_size=eval_config.max_image_size,\n    )\n\n    sampling_client = service_client.create_sampling_client(model_path=checkpoint.sampler_path)\n    metrics = await evaluator_builder()(sampling_client)  # type: ignore[arg-type]\n    return {\n        \"experiment_name\": experiment_name,\n        \"checkpoint_step\": checkpoint.batch,\n        **metrics,\n        **hyperparams,\n    }\n\n\nasync def evaluate_sweep(\n    eval_config: EvalConfig,\n    experiment_names: list[str],\n) -> dict[str, dict[str, Any]]:\n    \"\"\"\n    
Evaluate all few-shot image classifier experiments in a sweep directory.\n    \"\"\"\n\n    service_client = tinker.ServiceClient(base_url=eval_config.base_url)\n\n    # Limit concurrent evaluation tasks\n    semaphore = asyncio.Semaphore(eval_config.max_parallel_evals)\n\n    async def bounded_evaluate_experiment(experiment_name: str) -> dict[str, Any]:\n        async with semaphore:\n            return await evaluate_experiment(\n                experiment_name=experiment_name,\n                eval_config=eval_config,\n                service_client=service_client,\n            )\n\n    # Evaluate all experiments in parallel (bounded by semaphore)\n    logger.info(\n        f\"Submitting {len(experiment_names)} eval tasks (max {eval_config.max_parallel_evals} parallel)\"\n    )\n    async_tasks = [\n        asyncio.create_task(bounded_evaluate_experiment(name)) for name in experiment_names\n    ]\n\n    results = await asyncio.gather(*async_tasks)\n    return {metrics[\"experiment_name\"]: metrics for metrics in results}\n\n\ndef run_eval_sweep(eval_config: EvalConfig):\n    \"\"\"\n    Evaluate all few-shot image classifier experiments in a sweep directory.\n    \"\"\"\n\n    logging.basicConfig(level=logging.INFO)\n\n    experiment_dir = Path(eval_config.experiment_dir)\n    if not experiment_dir.is_dir():\n        raise ValueError(f\"Experiment directory does not exist: {eval_config.experiment_dir}\")\n\n    # Find all experiment subdirectories\n    experiment_names = sorted(d.name for d in experiment_dir.iterdir() if d.is_dir())\n\n    logger.info(\n        f\"Found {len(experiment_names)} experiment directories in {eval_config.experiment_dir}\"\n    )\n    classifier_results_json = asyncio.run(\n        evaluate_sweep(\n            eval_config=eval_config,\n            experiment_names=experiment_names,\n        )\n    )\n\n    # Save results to output file\n    Path(eval_config.output_file).resolve().parent.mkdir(parents=True, exist_ok=True)\n    with open(eval_config.output_file, \"w\") as f:\n        json.dump(classifier_results_json, f, indent=2)\n\n    logger.info(f\"Saved classifier results to {eval_config.output_file}\")\n    print(json.dumps(classifier_results_json, indent=2))\n\n\nif __name__ == \"__main__\":\n    chz.nested_entrypoint(run_eval_sweep)\n"
  },
  {
    "path": "tinker_cookbook/recipes/vlm_classifier/sweep.py",
    "content": "\"\"\"\n\n## VLM Image Classifier\n\nLauncher for training image classifiers based on VLMs.\n\n```bash\npython -m tinker_cookbook.recipes.vlm_classifier.sweep experiment_dir=./sweep model_name=Qwen/Qwen3-VL-30B-A3B-Instruct\n```\n\n\"\"\"\n\nimport asyncio\nfrom concurrent.futures import ProcessPoolExecutor\nfrom datetime import datetime\nfrom itertools import product\nfrom pathlib import Path\n\nimport chz\n\nfrom tinker_cookbook import cli_utils\nfrom tinker_cookbook.recipes.vlm_classifier.data import get_dataset_builder\nfrom tinker_cookbook.recipes.vlm_classifier.eval import get_evaluator_builder\nfrom tinker_cookbook.renderers import TrainOnWhat\nfrom tinker_cookbook.supervised import train\nfrom tinker_cookbook.utils.lr_scheduling import LRSchedule\n\n\n@chz.chz\nclass ExperimentConfig:\n    \"\"\"\n    Experiments for few-shot image classification with VLMs.\n    \"\"\"\n\n    experiment_dir: str\n\n    dataset: str = \"caltech101\"\n    renderer_name: str = \"qwen3_vl\"\n    model_name: str = \"Qwen/Qwen3-VL-235B-A22B-Instruct\"\n\n    # Infrastructure parameters\n    base_url: str | None = None\n\n    # Training parameters\n    learning_rate: float = 5e-4\n    num_epochs: int = 1\n    lr_schedule: LRSchedule = \"cosine\"\n\n    # Model parameters\n    lora_rank: int = 32\n\n    # Checkpointing and evaluation\n    save_every: int = 50\n    eval_every: int = 50\n    infrequent_eval_every: int = 100\n\n    # Logging parameters\n    wandb_project: str | None = None\n\n    train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE\n\n    num_repeats: float = 10\n    batch_size: int = 32\n    max_length: int = 8192\n\n    examples_per_class: int | None = None\n    subset_seed: int = 0\n\n    run_nll_evaluator: bool = False\n    run_sampling_evaluator: bool = True\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int = 256\n\n    max_steps: int | None = None\n\n\ndef run_experiment(experiment_config: ExperimentConfig):\n    \"\"\"\n    Run a supervised training experiment for a vlm classifier.\n    \"\"\"\n\n    # build full config\n    model_name = experiment_config.model_name.replace(\"/\", \"-\")\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d\")\n\n    # Include examples_per_class and subset_seed in run name if doing few-shot\n    shot_suffix = (\n        f\"-{experiment_config.examples_per_class}shot-seed{experiment_config.subset_seed}\"\n        if experiment_config.examples_per_class\n        else \"\"\n    )\n    experiment_name = f\"{experiment_config.dataset}-{model_name}-{experiment_config.lora_rank}rank-{experiment_config.learning_rate}lr-{experiment_config.batch_size}batch{shot_suffix}-{date_and_time}\"\n\n    experiment_path = str(Path(experiment_config.experiment_dir) / experiment_name)\n    cli_utils.check_log_dir(experiment_path, behavior_if_exists=\"delete\")\n\n    dataset_builder = get_dataset_builder(\n        dataset=experiment_config.dataset,\n        model_name_for_tokenizer=experiment_config.model_name,\n        renderer_name=experiment_config.renderer_name,\n        num_repeats=experiment_config.num_repeats,\n        batch_size=experiment_config.batch_size,\n        max_length=experiment_config.max_length,\n        train_on_what=experiment_config.train_on_what,\n        examples_per_class=experiment_config.examples_per_class,\n        subset_seed=experiment_config.subset_seed,\n        run_nll_evaluator=experiment_config.run_nll_evaluator,\n    )\n\n    
evaluator_builders = []\n    if experiment_config.run_sampling_evaluator:\n        evaluator_builders = [\n            get_evaluator_builder(\n                dataset=experiment_config.dataset,\n                model_name_for_tokenizer=experiment_config.model_name,\n                renderer_name=experiment_config.renderer_name,\n                temperature=experiment_config.temperature,\n                max_tokens=experiment_config.max_tokens,\n                top_p=experiment_config.top_p,\n                top_k=experiment_config.top_k,\n                n_eval=experiment_config.n_eval,\n            )\n        ]\n\n    config = train.Config(\n        log_path=experiment_path,\n        model_name=experiment_config.model_name,\n        renderer_name=experiment_config.renderer_name,\n        dataset_builder=dataset_builder,\n        evaluator_builders=evaluator_builders,\n        infrequent_evaluator_builders=[],\n        learning_rate=experiment_config.learning_rate,\n        lr_schedule=experiment_config.lr_schedule,\n        num_epochs=experiment_config.num_epochs,\n        base_url=experiment_config.base_url,\n        wandb_project=experiment_config.wandb_project,\n        wandb_name=experiment_name,\n        lora_rank=experiment_config.lora_rank,\n        save_every=experiment_config.save_every,\n        eval_every=experiment_config.eval_every,\n        infrequent_eval_every=experiment_config.infrequent_eval_every,\n        max_steps=experiment_config.max_steps,\n    )\n\n    asyncio.run(train.main(config))\n\n\n@chz.chz\nclass SweepConfig:\n    \"\"\"\n    Configuration for the sweep.\n    \"\"\"\n\n    experiment_dir: str\n\n    renderer_name: str = \"qwen3_vl\"\n    model_name: str = \"Qwen/Qwen3-VL-235B-A22B-Instruct\"\n\n    datasets: list[str] = chz.field(default_factory=lambda: [\"caltech101\"])\n    examples_per_class: list[int] = chz.field(default_factory=lambda: [1, 2, 4, 8, 16])\n\n    learning_rate: float = 1e-4\n    num_epochs: int = 1\n    lr_schedule: LRSchedule = \"constant\"\n\n    lora_rank: int = 32\n\n    num_repeats: float = 10\n    batch_size: int = 32\n    max_length: int = 8192\n\n    run_nll_evaluator: bool = False\n    run_sampling_evaluator: bool = True\n\n    base_url: str | None = None\n    wandb_project: str | None = None\n\n    # Number of experiments to run in parallel\n    num_parallel: int = 5\n\n\n# Adjust the number of epochs based on the amount of data\nEXAMPLES_TO_MULTIPLIER = {16: 1, 8: 2, 4: 4, 2: 8, 1: 16}\n\n\ndef run_sweep(sweep_config: SweepConfig):\n    \"\"\"\n    Run all experiments in parallel using ProcessPoolExecutor.\n    \"\"\"\n\n    experiment_configs = [\n        ExperimentConfig(\n            experiment_dir=sweep_config.experiment_dir,\n            model_name=sweep_config.model_name,\n            renderer_name=sweep_config.renderer_name,\n            dataset=target_dataset,\n            learning_rate=sweep_config.learning_rate,\n            num_epochs=sweep_config.num_epochs,\n            lr_schedule=sweep_config.lr_schedule,\n            lora_rank=sweep_config.lora_rank,\n            num_repeats=EXAMPLES_TO_MULTIPLIER[examples_per_class] * sweep_config.num_repeats,\n            batch_size=sweep_config.batch_size,\n            max_length=sweep_config.max_length,\n            examples_per_class=examples_per_class,\n            wandb_project=sweep_config.wandb_project,\n            base_url=sweep_config.base_url,\n            run_nll_evaluator=sweep_config.run_nll_evaluator,\n            
run_sampling_evaluator=sweep_config.run_sampling_evaluator,\n        )\n        for target_dataset, examples_per_class in product(\n            sweep_config.datasets, sweep_config.examples_per_class\n        )\n    ]\n\n    print(\n        f\"Running {len(experiment_configs)} experiments with {sweep_config.num_parallel} parallel workers\"\n    )\n\n    with ProcessPoolExecutor(max_workers=sweep_config.num_parallel) as executor:\n        futures = [executor.submit(run_experiment, config) for config in experiment_configs]\n        results = [f.result() for f in futures]\n        print(f\"{len(results)} experiments finished running\")\n\n\nif __name__ == \"__main__\":\n    chz.nested_entrypoint(run_sweep)\n"
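Note: the sweep above is normally launched through `chz.nested_entrypoint` as shown in its module docstring. As a minimal sketch (field values are illustrative, not recommendations), the same sweep can also be driven programmatically:

```python
# Minimal sketch: constructing SweepConfig in code and running the sweep directly.
# The model name and field values are illustrative examples only.
from tinker_cookbook.recipes.vlm_classifier.sweep import SweepConfig, run_sweep

sweep_config = SweepConfig(
    experiment_dir="./sweep",                      # per-experiment log dirs are created under here
    model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
    datasets=["caltech101"],
    examples_per_class=[1, 4, 16],                 # few-shot settings to sweep over
    num_parallel=3,                                # experiments run in separate processes
)
run_sweep(sweep_config)
```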
  },
  {
    "path": "tinker_cookbook/recipes/vlm_classifier/train.py",
    "content": "\"\"\"\n\n## VLM Image Classifier\n\nLauncher for training image classifiers based on VLMs.\n\n```bash\npython -m tinker_cookbook.recipes.vlm_classifier.train experiment_dir=./vlm_classifier\n```\n\n\"\"\"\n\nimport asyncio\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Literal\n\nimport chz\n\nfrom tinker_cookbook import cli_utils\nfrom tinker_cookbook.recipes.vlm_classifier.data import get_dataset_builder\nfrom tinker_cookbook.recipes.vlm_classifier.eval import get_evaluator_builder\nfrom tinker_cookbook.renderers import TrainOnWhat\nfrom tinker_cookbook.supervised import train\nfrom tinker_cookbook.utils.lr_scheduling import LRSchedule\n\n\n@chz.chz\nclass ExperimentConfig:\n    \"\"\"\n    Experiments for few-shot image classification with VLMs.\n    \"\"\"\n\n    experiment_dir: str\n    load_checkpoint_path: str | None = None\n\n    dataset: str = \"caltech101\"\n\n    renderer_name: str = \"qwen3_vl\"\n    model_name: str = \"Qwen/Qwen3-VL-235B-A22B-Instruct\"\n\n    # Infrastructure parameters\n    base_url: str | None = None\n    behavior_if_log_dir_exists: Literal[\"delete\", \"resume\", \"ask\", \"raise\"] = \"ask\"\n\n    # Training parameters\n    learning_rate: float = 5e-4\n    num_epochs: int = 3\n    lr_schedule: LRSchedule = \"cosine\"\n\n    # Model parameters\n    lora_rank: int = 32\n\n    # Checkpointing and evaluation\n    save_every: int = 20\n    eval_every: int = 20\n    infrequent_eval_every: int = 100\n\n    # Logging parameters\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE\n\n    num_repeats: float = 1\n    batch_size: int = 32\n    max_length: int = 8192\n\n    examples_per_class: int | None = None\n    subset_seed: int = 0\n\n    run_nll_evaluator: bool = True\n    run_sampling_evaluator: bool = True\n\n    temperature: float = 0.0\n    max_tokens: int = 128\n    top_p: float = 1.0\n    top_k: int = -1\n\n    n_eval: int = 128\n\n    max_steps: int | None = None\n\n\ndef run_experiment(experiment_config: ExperimentConfig):\n    \"\"\"\n    Launcher for training an image classifier based on a VLM on a custom vision dataset.\n    \"\"\"\n\n    # build full config\n    model_name = experiment_config.model_name.replace(\"/\", \"-\")\n    date_and_time = datetime.now().strftime(\"%Y-%m-%d\")\n\n    # Include examples_per_class and subset_seed in run name if doing few-shot\n    shot_suffix = (\n        f\"-{experiment_config.examples_per_class}shot-seed{experiment_config.subset_seed}\"\n        if experiment_config.examples_per_class\n        else \"\"\n    )\n    experiment_name = f\"{experiment_config.dataset}-{model_name}-{experiment_config.lora_rank}rank-{experiment_config.learning_rate}lr-{experiment_config.batch_size}batch{shot_suffix}-{date_and_time}\"\n\n    experiment_path = str(Path(experiment_config.experiment_dir) / experiment_name)\n    cli_utils.check_log_dir(\n        experiment_path, behavior_if_exists=experiment_config.behavior_if_log_dir_exists\n    )\n\n    dataset_builder = get_dataset_builder(\n        dataset=experiment_config.dataset,\n        model_name_for_tokenizer=experiment_config.model_name,\n        renderer_name=experiment_config.renderer_name,\n        num_repeats=experiment_config.num_repeats,\n        batch_size=experiment_config.batch_size,\n        max_length=experiment_config.max_length,\n        train_on_what=experiment_config.train_on_what,\n        
examples_per_class=experiment_config.examples_per_class,\n        subset_seed=experiment_config.subset_seed,\n        run_nll_evaluator=experiment_config.run_nll_evaluator,\n    )\n\n    evaluator_builders = []\n    if experiment_config.run_sampling_evaluator:\n        evaluator_builders = [\n            get_evaluator_builder(\n                dataset=experiment_config.dataset,\n                model_name_for_tokenizer=experiment_config.model_name,\n                renderer_name=experiment_config.renderer_name,\n                temperature=experiment_config.temperature,\n                max_tokens=experiment_config.max_tokens,\n                top_p=experiment_config.top_p,\n                top_k=experiment_config.top_k,\n                n_eval=experiment_config.n_eval,\n            )\n        ]\n\n    config = train.Config(\n        log_path=experiment_path,\n        model_name=experiment_config.model_name,\n        renderer_name=experiment_config.renderer_name,\n        load_checkpoint_path=experiment_config.load_checkpoint_path,\n        dataset_builder=dataset_builder,\n        evaluator_builders=evaluator_builders,\n        infrequent_evaluator_builders=[],\n        learning_rate=experiment_config.learning_rate,\n        lr_schedule=experiment_config.lr_schedule,\n        num_epochs=experiment_config.num_epochs,\n        base_url=experiment_config.base_url,\n        wandb_project=experiment_config.wandb_project,\n        wandb_name=experiment_config.wandb_name or experiment_name,\n        lora_rank=experiment_config.lora_rank,\n        save_every=experiment_config.save_every,\n        eval_every=experiment_config.eval_every,\n        infrequent_eval_every=experiment_config.infrequent_eval_every,\n        max_steps=experiment_config.max_steps,\n    )\n\n    asyncio.run(train.main(config))\n\n\nif __name__ == \"__main__\":\n    experiment_config = chz.entrypoint(ExperimentConfig)\n    run_experiment(experiment_config)\n"
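Note: a minimal sketch of invoking this launcher programmatically instead of via `chz.entrypoint`; the field values are illustrative:

```python
# Minimal sketch: constructing ExperimentConfig in code and calling run_experiment directly.
from tinker_cookbook.recipes.vlm_classifier.train import ExperimentConfig, run_experiment

experiment_config = ExperimentConfig(
    experiment_dir="./vlm_classifier",
    dataset="caltech101",
    examples_per_class=4,                     # few-shot subset size; the default None does not subset
    behavior_if_log_dir_exists="resume",      # reuse an existing log dir instead of deleting or asking
)
run_experiment(experiment_config)
```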
  },
  {
    "path": "tinker_cookbook/renderers/README.md",
    "content": "Check out the [rendering docs](https://tinker-docs.thinkingmachines.ai/rendering) for more details on the rendering system, and the [sequence extension docs](https://tinker-docs.thinkingmachines.ai/rl/sequence-extension) for more details on the sequence extension property.\n\nSome tips for development of renderers:\n\n- As it's hard to audit all the parsing and prompt-building logic, focus on writing tests for properties that these functions should have:\n\n    - Exact correspondence between `build_generation_prompt` and `apply_chat_template` from HuggingFace transformers\n    - Correspondence between `build_supervised_example` and `build_generation_prompt`: if you you turn the conversation into a supervised example (tokens and weights), and split the tokens into observation and action, then the observation tokens should match the tokens that would be generated by `build_generation_prompt` for the same conversation prefix (all but the last assistant message).\n    - Correspondence between `build_supervised_example` and `parse_response`: if you turn the conversation into a supervised example as described above, and parse the action tokens back into a message, you should get the same final assistant message back that you started with.\n    - Sequence extension property: for applicable models, given a series of message user1, assistant1, user2, assistant2, ..., the observation from user1 should be a prefix of the observation from user2, and so on.\n\n- For LLM assisted development, do some web research to put together specification of the token-level formatting conventions:\n\n    - Do a local checkout of vLLM or SGLang and tell the LLM to use it for reference for determining token formats\n    - Tell a web research agent to look up the token-level formatting conventions for the model in question (using developer docs, tech reports, chat templates, and other sources), and compile a markdown-formatted report that describes the rendering and parsing behavior.\n"
  },
  {
    "path": "tinker_cookbook/renderers/__init__.py",
    "content": "\"\"\"\nRenderers for converting message lists into training and sampling prompts.\n\nUse viz_sft_dataset to visualize the output of different renderers. E.g.,\n    python -m tinker_cookbook.supervised.viz_sft_dataset dataset_path=Tulu3Builder renderer_name=role_colon\n\"\"\"\n\nfrom collections.abc import Callable\nfrom typing import Any\n\nfrom tinker_cookbook.exceptions import RendererError\nfrom tinker_cookbook.image_processing_utils import ImageProcessor\n\n# Types and utilities used by external code\nfrom tinker_cookbook.renderers.base import (\n    # Content part types\n    ContentPart,\n    ImagePart,\n    Message,\n    # Streaming types\n    MessageDelta,\n    # Renderer base\n    RenderContext,\n    Renderer,\n    Role,\n    StreamingMessageHeader,\n    StreamingTextDelta,\n    StreamingThinkingDelta,\n    TextPart,\n    ThinkingPart,\n    ToolCall,\n    ToolSpec,\n    TrainOnWhat,\n    Utf8TokenDecoder,\n    # Utility functions\n    ensure_text,\n    format_content_as_string,\n    get_text_content,\n    parse_content_blocks,\n)\n\n# Renderer classes used directly by tests\nfrom tinker_cookbook.renderers.deepseek_v3 import DeepSeekV3ThinkingRenderer\nfrom tinker_cookbook.renderers.gpt_oss import GptOssRenderer\nfrom tinker_cookbook.renderers.qwen3 import Qwen3Renderer\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n# Global registry for custom renderer factories\n_CUSTOM_RENDERER_REGISTRY: dict[str, Callable[[Tokenizer, Any], Renderer]] = {}\n\n\ndef register_renderer(\n    name: str,\n    factory: Callable[[Tokenizer, Any], Renderer],\n) -> None:\n    \"\"\"Register a custom renderer factory.\n\n    Args:\n        name: The renderer name\n        factory: A callable that takes (tokenizer, image_processor) and returns a Renderer.\n\n    Example:\n        def my_renderer_factory(tokenizer, image_processor=None):\n            return MyCustomRenderer(tokenizer)\n\n        register_renderer(\"Foo/foo_renderer\", my_renderer_factory)\n    \"\"\"\n    _CUSTOM_RENDERER_REGISTRY[name] = factory\n\n\ndef get_registered_renderer_names() -> list[str]:\n    \"\"\"Return a list of all registered custom renderer names.\"\"\"\n    return list(_CUSTOM_RENDERER_REGISTRY.keys())\n\n\ndef is_renderer_registered(name: str) -> bool:\n    \"\"\"Check if a renderer name is registered.\"\"\"\n    return name in _CUSTOM_RENDERER_REGISTRY\n\n\ndef unregister_renderer(name: str) -> bool:\n    \"\"\"Unregister a custom renderer factory.\n\n    Args:\n        name: The renderer name to unregister.\n\n    Returns:\n        True if the renderer was unregistered, False if it wasn't registered.\n    \"\"\"\n    if name in _CUSTOM_RENDERER_REGISTRY:\n        del _CUSTOM_RENDERER_REGISTRY[name]\n        return True\n    return False\n\n\ndef get_renderer(\n    name: str,\n    tokenizer: Tokenizer,\n    image_processor: ImageProcessor | None = None,\n    model_name: str | None = None,\n) -> Renderer:\n    \"\"\"Factory function to create renderers by name.\n\n    Args:\n        name: Renderer name. 
Supported values:\n            - \"role_colon\": Simple role:content format\n            - \"llama3\": Llama 3 chat format\n            - \"qwen3\": Qwen3 with thinking enabled\n            - \"qwen3_vl\": Qwen3 vision-language with thinking\n            - \"qwen3_vl_instruct\": Qwen3 vision-language instruct (no thinking)\n            - \"qwen3_disable_thinking\": Qwen3 with thinking disabled\n            - \"qwen3_instruct\": Qwen3 instruct 2507 (no thinking)\n            - \"qwen3_5\": Qwen3.5 VL with thinking\n            - \"qwen3_5_disable_thinking\": Qwen3.5 VL with thinking disabled\n            - \"deepseekv3\": DeepSeek V3 (defaults to non-thinking mode)\n            - \"deepseekv3_disable_thinking\": DeepSeek V3 non-thinking (alias)\n            - \"deepseekv3_thinking\": DeepSeek V3 thinking mode\n            - \"kimi_k2\": Kimi K2 Thinking format\n            - \"kimi_k25\": Kimi K2.5 with thinking enabled\n            - \"kimi_k25_disable_thinking\": Kimi K2.5 with thinking disabled\n            - \"nemotron3\": Nemotron-3 with thinking enabled\n            - \"nemotron3_disable_thinking\": Nemotron-3 with thinking disabled\n            - \"gpt_oss_no_sysprompt\": GPT-OSS without system prompt\n            - \"gpt_oss_low_reasoning\": GPT-OSS with low reasoning\n            - \"gpt_oss_medium_reasoning\": GPT-OSS with medium reasoning\n            - \"gpt_oss_high_reasoning\": GPT-OSS with high reasoning\n            - Custom renderers registered via register_renderer()\n        tokenizer: The tokenizer to use.\n        image_processor: Required for VL renderers.\n        model_name: Model name for pickle metadata. If None, falls back to\n            ``tokenizer.name_or_path``. Provide this explicitly when the tokenizer\n            was loaded with a remapped name (e.g., Llama 3 models).\n\n    Returns:\n        A Renderer instance.\n\n    Raises:\n        ValueError: If the renderer name is unknown.\n        AssertionError: If a VL renderer is requested without an image_processor.\n    \"\"\"\n\n    def _stamp_pickle_metadata(renderer: Renderer) -> Renderer:\n        \"\"\"Stamp renderer with metadata needed for pickle support.\"\"\"\n        renderer._renderer_name = name\n        renderer._model_name = model_name if model_name is not None else tokenizer.name_or_path\n        renderer._has_image_processor = image_processor is not None\n        return renderer\n\n    # Check custom registry first\n    if (factory := _CUSTOM_RENDERER_REGISTRY.get(name)) is not None:\n        return _stamp_pickle_metadata(factory(tokenizer, image_processor))\n\n    # Import renderer classes lazily to avoid circular imports and keep exports minimal\n    from tinker_cookbook.renderers.deepseek_v3 import DeepSeekV3DisableThinkingRenderer\n    from tinker_cookbook.renderers.gpt_oss import GptOssRenderer\n    from tinker_cookbook.renderers.kimi_k2 import KimiK2Renderer\n    from tinker_cookbook.renderers.kimi_k25 import KimiK25DisableThinkingRenderer, KimiK25Renderer\n    from tinker_cookbook.renderers.llama3 import Llama3Renderer\n    from tinker_cookbook.renderers.nemotron3 import (\n        Nemotron3DisableThinkingRenderer,\n        Nemotron3Renderer,\n    )\n    from tinker_cookbook.renderers.qwen3 import (\n        Qwen3DisableThinkingRenderer,\n        Qwen3InstructRenderer,\n        Qwen3VLInstructRenderer,\n        Qwen3VLRenderer,\n    )\n    from tinker_cookbook.renderers.qwen3_5 import Qwen3_5DisableThinkingRenderer, Qwen3_5Renderer\n    from tinker_cookbook.renderers.role_colon import 
RoleColonRenderer\n\n    renderer: Renderer\n    if name == \"role_colon\":\n        renderer = RoleColonRenderer(tokenizer)\n    elif name == \"llama3\":\n        renderer = Llama3Renderer(tokenizer)\n    elif name == \"qwen3\":\n        renderer = Qwen3Renderer(tokenizer)\n    elif name == \"qwen3_vl\":\n        assert image_processor is not None, \"qwen3_vl renderer requires an image_processor\"\n        renderer = Qwen3VLRenderer(tokenizer, image_processor)\n    elif name == \"qwen3_vl_instruct\":\n        assert image_processor is not None, \"qwen3_vl_instruct renderer requires an image_processor\"\n        renderer = Qwen3VLInstructRenderer(tokenizer, image_processor)\n    elif name == \"qwen3_disable_thinking\":\n        renderer = Qwen3DisableThinkingRenderer(tokenizer)\n    elif name == \"qwen3_instruct\":\n        renderer = Qwen3InstructRenderer(tokenizer)\n    elif name == \"qwen3_5\":\n        renderer = Qwen3_5Renderer(tokenizer, image_processor=image_processor)\n    elif name == \"qwen3_5_disable_thinking\":\n        renderer = Qwen3_5DisableThinkingRenderer(tokenizer, image_processor=image_processor)\n    elif name == \"deepseekv3\":\n        # Default to non-thinking mode (matches HF template default behavior)\n        renderer = DeepSeekV3DisableThinkingRenderer(tokenizer)\n    elif name == \"deepseekv3_disable_thinking\":\n        # Alias for backward compatibility\n        renderer = DeepSeekV3DisableThinkingRenderer(tokenizer)\n    elif name == \"deepseekv3_thinking\":\n        renderer = DeepSeekV3ThinkingRenderer(tokenizer)\n    elif name == \"kimi_k2\":\n        renderer = KimiK2Renderer(tokenizer)\n    elif name == \"kimi_k25\":\n        renderer = KimiK25Renderer(tokenizer, image_processor=image_processor)\n    elif name == \"kimi_k25_disable_thinking\":\n        renderer = KimiK25DisableThinkingRenderer(tokenizer, image_processor=image_processor)\n    elif name == \"nemotron3\":\n        renderer = Nemotron3Renderer(tokenizer)\n    elif name == \"nemotron3_disable_thinking\":\n        renderer = Nemotron3DisableThinkingRenderer(tokenizer)\n    elif name == \"gpt_oss_no_sysprompt\":\n        renderer = GptOssRenderer(tokenizer, use_system_prompt=False)\n    elif name == \"gpt_oss_low_reasoning\":\n        renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"low\")\n    elif name == \"gpt_oss_medium_reasoning\":\n        renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n    elif name == \"gpt_oss_high_reasoning\":\n        renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"high\")\n    else:\n        raise RendererError(\n            f\"Unknown renderer: {name}. 
If this is a custom renderer, please register it via register_renderer().\"\n        )\n\n    return _stamp_pickle_metadata(renderer)\n\n\n__all__ = [\n    # Types\n    \"ContentPart\",\n    \"ImagePart\",\n    \"Message\",\n    \"Role\",\n    \"TextPart\",\n    \"ThinkingPart\",\n    \"ToolCall\",\n    \"ToolSpec\",\n    # Streaming types\n    \"MessageDelta\",\n    \"StreamingMessageHeader\",\n    \"StreamingTextDelta\",\n    \"StreamingThinkingDelta\",\n    \"Utf8TokenDecoder\",\n    # Renderer base\n    \"RenderContext\",\n    \"Renderer\",\n    \"TrainOnWhat\",\n    # Utility functions\n    \"ensure_text\",\n    \"format_content_as_string\",\n    \"get_text_content\",\n    \"parse_content_blocks\",\n    # Registry\n    \"register_renderer\",\n    \"unregister_renderer\",\n    \"get_registered_renderer_names\",\n    \"is_renderer_registered\",\n    # Factory\n    \"get_renderer\",\n    # Renderer classes (used by tests)\n    \"DeepSeekV3ThinkingRenderer\",\n    \"GptOssRenderer\",\n    \"Qwen3Renderer\",\n]\n"
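A minimal usage sketch for the factory above; the model name is illustrative, and `get_tokenizer`/`get_image_processor` are the cookbook's cached loaders:

```python
# Minimal sketch: building a VL renderer by name via get_renderer.
from tinker_cookbook.image_processing_utils import get_image_processor
from tinker_cookbook.renderers import get_renderer
from tinker_cookbook.tokenizer_utils import get_tokenizer

model_name = "Qwen/Qwen3-VL-235B-A22B-Instruct"  # illustrative model choice
tokenizer = get_tokenizer(model_name)
image_processor = get_image_processor(model_name)  # VL renderers require an image processor
renderer = get_renderer("qwen3_vl", tokenizer, image_processor, model_name=model_name)
```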
  },
  {
    "path": "tinker_cookbook/renderers/base.py",
    "content": "\"\"\"\nBase types, utilities, and abstract Renderer class for message rendering.\n\nUse viz_sft_dataset to visualize the output of different renderers. E.g.,\n    python -m tinker_cookbook.supervised.viz_sft_dataset dataset_path=Tulu3Builder renderer_name=role_colon\n\"\"\"\n\nimport io\nimport json\nimport logging\nimport pickle\nimport re\nimport urllib.request\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Callable, Iterator\nfrom dataclasses import dataclass, field\nfrom enum import StrEnum\nfrom typing import (\n    Any,\n    Literal,\n    NotRequired,\n    Protocol,\n    TypedDict,\n    Union,\n)\n\nimport pydantic\nimport tinker\nimport torch\nfrom PIL import Image\n\nfrom tinker_cookbook.exceptions import RendererError\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\nlogger = logging.getLogger(__name__)\n\n# Tool types are based on kosong (https://github.com/MoonshotAI/kosong).\n\n\nclass StrictBase(pydantic.BaseModel):\n    \"\"\"\n    Pydantic base class that's immutable and doesn't silently ignore extra fields.\n    \"\"\"\n\n    model_config = pydantic.ConfigDict(frozen=True, extra=\"forbid\")\n\n    def __str__(self) -> str:\n        return repr(self)\n\n\nclass ToolCall(StrictBase):\n    \"\"\"\n    Structured tool invocation following OpenAI/kosong format.\n\n    This represents a request to invoke a tool/function. The structure follows\n    the OpenAI function calling format for compatibility with various LLM APIs.\n\n    Example:\n        tool_call = ToolCall(\n            function=ToolCall.FunctionBody(\n                name=\"search\",\n                arguments='{\"query_list\": [\"python async\", \"pydantic validation\"]}'\n            ),\n            id=\"call_abc123\"\n        )\n    \"\"\"\n\n    class FunctionBody(pydantic.BaseModel):\n        \"\"\"\n        Tool call function body containing the tool name and arguments.\n\n        The arguments field must be a valid JSON string that will be parsed\n        by the tool implementation.\n        \"\"\"\n\n        name: str\n        \"\"\"The name of the tool to be called.\"\"\"\n        arguments: str\n        \"\"\"Arguments of the tool call in JSON string format.\"\"\"\n\n    type: Literal[\"function\"] = \"function\"\n    \"\"\"Tool call type, must be 'function' for compatibility.\"\"\"\n\n    id: str | None = None\n    \"\"\"Optional unique identifier for tracking this specific tool call.\"\"\"\n\n    function: FunctionBody\n    \"\"\"The function body containing tool name and arguments.\"\"\"\n\n\nclass UnparsedToolCall(StrictBase):\n    \"\"\"\n    Represents a tool call that failed to parse from model output.\n\n    When a model generates text that looks like a tool call but cannot be\n    parsed (e.g., invalid JSON), this class captures the raw text and error\n    for debugging and optional re-rendering.\n\n    Example:\n        unparsed = UnparsedToolCall(\n            raw_text='<tool_call>{\"name\": \"search\", invalid json}</tool_call>',\n            error=\"Invalid JSON: Expecting property name\"\n        )\n    \"\"\"\n\n    raw_text: str\n    \"\"\"The original text from the model that failed to parse.\"\"\"\n\n    error: str\n    \"\"\"Description of what went wrong during parsing.\"\"\"\n\n\nclass TextPart(TypedDict):\n    \"\"\"A chunk of text content in a message, usually meant to be visible to the user\n    (unlike ThinkingPart, which is internal reasoning).\"\"\"\n\n    type: Literal[\"text\"]\n    text: str\n\n\nclass ImagePart(TypedDict):\n    
\"\"\"\n    A chunk of image content in a message.\n    \"\"\"\n\n    type: Literal[\"image\"]\n    image: str | Image.Image\n\n\nclass ThinkingPart(TypedDict):\n    \"\"\"Model's internal reasoning (chain-of-thought) as a content part.\"\"\"\n\n    type: Literal[\"thinking\"]\n    thinking: str  # The thinking/reasoning content\n\n\n# Container for a part of a multimodal message content.\n# Tool calls live exclusively in message[\"tool_calls\"] / message[\"unparsed_tool_calls\"].\nContentPart = TextPart | ImagePart | ThinkingPart\n\n\n# Streaming types to enable incremental parsing of model output for real-time display.\n\n\n@dataclass\nclass StreamingMessageHeader:\n    \"\"\"Emitted at the start of a new message during streaming.\n\n    This signals that a new message is beginning and provides the author info.\n    \"\"\"\n\n    role: str\n    name: str | None = None\n\n\n@dataclass\nclass StreamingTextDelta:\n    \"\"\"Incremental text content during streaming.\n\n    Contains only the new text since the last delta, not the accumulated text.\n    The recipient should concatenate deltas to build the full content.\n    \"\"\"\n\n    text: str\n    content_index: int = 0\n    \"\"\"Index of this content block within the message. Increments when content type changes.\"\"\"\n\n\n@dataclass\nclass StreamingThinkingDelta:\n    \"\"\"Incremental thinking/reasoning content during streaming.\n\n    Contains only the new thinking text since the last delta.\n    \"\"\"\n\n    thinking: str\n    content_index: int = 0\n    \"\"\"Index of this content block within the message. Increments when content type changes.\"\"\"\n\n\n# Union of all streaming update types.\n# A streaming parser yields these in sequence:\n# 1. StreamingMessageHeader (once at start)\n# 2. StreamingTextDelta / StreamingThinkingDelta (as content arrives)\n# 3. Message (once at end, containing the complete parsed message)\nMessageDelta = Union[StreamingMessageHeader, StreamingTextDelta, StreamingThinkingDelta, \"Message\"]\n\n\n# Unicode replacement character - indicates incomplete/invalid UTF-8 sequence\n_REPLACEMENT_CHAR = \"\\ufffd\"\n\n\n@dataclass\nclass Utf8TokenDecoder:\n    \"\"\"Handles incremental UTF-8 decoding from tokens.\n\n    Tokens can split multi-byte UTF-8 sequences (e.g., a 3-byte character\n    might be split across 2 tokens). This class buffers tokens until a\n    valid UTF-8 string can be decoded.\n\n    Detection strategy:\n    1. Try decoding all pending + new tokens\n    2. If result contains trailing U+FFFD (replacement char), it's incomplete\n    3. Scan backwards to find longest prefix without trailing replacement chars\n    4. 
Emit that prefix, buffer the rest\n\n    This handles tiktoken-style tokenizers that return replacement chars\n    instead of throwing exceptions for incomplete UTF-8.\n    \"\"\"\n\n    tokenizer: \"Tokenizer\"\n    _pending_tokens: list[int] = None  # type: ignore[assignment]\n\n    def __post_init__(self) -> None:\n        if self._pending_tokens is None:\n            self._pending_tokens = []\n\n    # Max tokens to try removing from the end when looking for decodable prefix.\n    # UTF-8 chars are max 4 bytes, tokens typically 1-4 bytes each,\n    # so 8 tokens is plenty to cover any incomplete trailing sequence.\n    _MAX_TRAILING_TOKENS_TO_TRY: int = 8\n\n    def _is_valid_decode(self, text: str) -> bool:\n        \"\"\"Check if decoded text represents a complete UTF-8 sequence.\n\n        Returns False if the text ends with a replacement character,\n        which indicates an incomplete multi-byte sequence that needs\n        more tokens to complete.\n        \"\"\"\n        return not text.endswith(_REPLACEMENT_CHAR)\n\n    def decode(self, tokens: list[int]) -> str | None:\n        \"\"\"Decode tokens to string, buffering incomplete UTF-8 sequences.\n\n        Args:\n            tokens: New tokens to decode.\n\n        Returns:\n            Decoded string if complete UTF-8 sequences are available,\n            None if all tokens were buffered (incomplete sequence).\n        \"\"\"\n        self._pending_tokens.extend(tokens)\n\n        # Try to decode all pending tokens (common case)\n        try:\n            text = str(self.tokenizer.decode(self._pending_tokens))\n            if self._is_valid_decode(text):\n                self._pending_tokens = []\n                return text\n            # Has trailing replacement chars - fall through to find valid prefix\n        except Exception:\n            pass\n\n        # Scan backwards to find longest decodable prefix without replacement chars.\n        # We only need to try removing a few tokens since UTF-8 sequences are at\n        # most 4 bytes and tokens are typically 1-4 bytes each.\n        for remove in range(\n            1, min(len(self._pending_tokens), self._MAX_TRAILING_TOKENS_TO_TRY) + 1\n        ):\n            prefix = self._pending_tokens[:-remove]\n            if not prefix:\n                break\n            try:\n                text = str(self.tokenizer.decode(prefix))\n                if self._is_valid_decode(text):\n                    self._pending_tokens = self._pending_tokens[-remove:]\n                    return text\n            except Exception:\n                continue\n\n        # All tokens buffered - need more data\n        return None\n\n    def flush(self) -> str:\n        \"\"\"Force decode any remaining tokens.\n\n        Call this at end of stream. 
May produce replacement characters\n        for incomplete sequences.\n        \"\"\"\n        if not self._pending_tokens:\n            return \"\"\n        try:\n            text = str(self.tokenizer.decode(self._pending_tokens))\n        except Exception:\n            # Last resort: decode with errors='replace' behavior\n            # Most tokenizers handle this, but fall back to empty string\n            text = \"\"\n        self._pending_tokens = []\n        return text\n\n    def reset(self) -> None:\n        \"\"\"Clear any buffered tokens.\"\"\"\n        self._pending_tokens = []\n\n    def has_pending(self) -> bool:\n        \"\"\"Check if there are buffered tokens waiting for more data.\"\"\"\n        return len(self._pending_tokens) > 0\n\n\n# =============================================================================\n# Streaming Parsers\n# =============================================================================\n\n\ndef _longest_matching_suffix_prefix(text: str, tag: str) -> int:\n    \"\"\"Find longest suffix of text that matches a prefix of tag.\n\n    This is used during streaming to determine how many characters at the end\n    of accumulated text might be the beginning of a tag, and thus shouldn't\n    be emitted yet.\n\n    Args:\n        text: The accumulated text to check.\n        tag: The tag we're looking for (e.g., \"<think>\").\n\n    Returns:\n        Length of the longest suffix of text that matches a prefix of tag.\n\n    Examples:\n        >>> _longest_matching_suffix_prefix(\"hello\", \"<think>\")\n        0  # no suffix matches any prefix\n        >>> _longest_matching_suffix_prefix(\"hello<\", \"<think>\")\n        1  # \"<\" matches prefix \"<\"\n        >>> _longest_matching_suffix_prefix(\"hello<th\", \"<think>\")\n        3  # \"<th\" matches prefix \"<th\"\n        >>> _longest_matching_suffix_prefix(\"hello<thx\", \"<think>\")\n        0  # \"<thx\" doesn't match any prefix of \"<think>\"\n    \"\"\"\n    max_check = min(len(text), len(tag) - 1)  # -1 because full tag would be found, not buffered\n    for length in range(max_check, 0, -1):\n        if text.endswith(tag[:length]):\n            return length\n    return 0\n\n\n@dataclass\nclass StreamingParser:\n    \"\"\"Base streaming parser for incremental token-to-delta conversion.\n\n    Handles the generic plumbing shared by all streaming parsers:\n    - Token-by-token feeding with end-token detection\n    - UTF-8 decoding across token boundaries\n    - Header emission on first content\n    - Final message construction via callback\n\n    Subclasses override ``_emit_deltas`` to implement model-specific parsing\n    (e.g., detecting ``<think>`` tags for reasoning models).\n\n    Usage::\n\n        parser = StreamingParser(tokenizer, end_token, parse_final_response)\n        for token in response_tokens:\n            for delta in parser.feed(token):\n                # Handle delta\n        for delta in parser.finish():\n            # Handle final deltas including complete Message\n    \"\"\"\n\n    tokenizer: \"Tokenizer\"\n    end_message_token: int\n    parse_final_response: Callable[[list[int]], tuple[\"Message\", bool]]\n\n    _utf8_decoder: Utf8TokenDecoder = field(init=False)\n    _accumulated_text: str = field(init=False, default=\"\")\n    _header_emitted: bool = field(init=False, default=False)\n    _content_index: int = field(init=False, default=0)\n    _last_emitted_pos: int = field(init=False, default=0)\n    _finished: bool = field(init=False, default=False)\n    _all_tokens: 
list[int] = field(init=False, default_factory=list)\n\n    def __post_init__(self) -> None:\n        self._utf8_decoder = Utf8TokenDecoder(self.tokenizer)\n        self._accumulated_text = \"\"\n        self._header_emitted = False\n        self._content_index = 0\n        self._last_emitted_pos = 0\n        self._finished = False\n        self._all_tokens = []\n\n    def feed(self, token: int) -> Iterator[\"MessageDelta\"]:\n        \"\"\"Feed a single token and yield any resulting deltas.\"\"\"\n        if self._finished:\n            return\n\n        self._all_tokens.append(token)\n\n        if token == self.end_message_token:\n            self._finished = True\n            return\n\n        decoded = self._utf8_decoder.decode([token])\n        if decoded is None:\n            return\n\n        self._accumulated_text += decoded\n\n        if not self._header_emitted:\n            self._header_emitted = True\n            yield StreamingMessageHeader(role=\"assistant\")\n\n        yield from self._emit_deltas()\n\n    def _emit_deltas(self) -> Iterator[\"MessageDelta\"]:\n        \"\"\"Emit deltas for any new content since last emission.\n\n        The base implementation emits all new text as StreamingTextDelta.\n        Subclasses override this to handle model-specific markup.\n        \"\"\"\n        text = self._accumulated_text\n        pos = self._last_emitted_pos\n        if pos < len(text):\n            new_text = text[pos:]\n            if new_text:\n                yield StreamingTextDelta(text=new_text, content_index=self._content_index)\n            self._last_emitted_pos = len(text)\n\n    def _emit_remaining(self) -> Iterator[\"MessageDelta\"]:\n        \"\"\"Emit any remaining buffered content at end of stream.\n\n        The base implementation emits remaining text as StreamingTextDelta.\n        Subclasses override this for type-aware emission (e.g., thinking vs text).\n        \"\"\"\n        text = self._accumulated_text\n        pos = self._last_emitted_pos\n        if pos < len(text):\n            remaining = text[pos:]\n            if remaining:\n                yield StreamingTextDelta(text=remaining, content_index=self._content_index)\n\n    def finish(self) -> Iterator[\"MessageDelta\"]:\n        \"\"\"Finish parsing and yield any remaining content plus final Message.\n\n        Call this after all tokens have been fed.\n        \"\"\"\n        remaining = self._utf8_decoder.flush()\n        if remaining:\n            self._accumulated_text += remaining\n\n        if not self._header_emitted:\n            self._header_emitted = True\n            yield StreamingMessageHeader(role=\"assistant\")\n\n        yield from self._emit_remaining()\n\n        message, _success = self.parse_final_response(self._all_tokens)\n        yield message\n\n    def reset(self) -> None:\n        \"\"\"Reset parser state for reuse.\"\"\"\n        self._utf8_decoder.reset()\n        self._accumulated_text = \"\"\n        self._header_emitted = False\n        self._content_index = 0\n        self._last_emitted_pos = 0\n        self._finished = False\n        self._all_tokens = []\n\n\n# Tags used by reasoning models (Qwen3, Kimi K2, DeepSeek, etc.)\n_THINK_OPEN_TAG = \"<think>\"\n_THINK_CLOSE_TAG = \"</think>\"\n\n\n@dataclass\nclass ReasoningStreamingParser(StreamingParser):\n    \"\"\"Streaming parser for models that use ``<think>...</think>`` reasoning blocks.\n\n    Extends StreamingParser with a state machine that detects ``<think>`` and\n    ``</think>`` tag boundaries, emitting 
StreamingThinkingDelta for reasoning\n    content and StreamingTextDelta for regular content. Handles partial tags\n    that may be split across token boundaries.\n\n    Used by renderers for Qwen3, Kimi K2, and other models that follow the\n    ``<think>...</think>`` convention for chain-of-thought reasoning.\n    \"\"\"\n\n    _in_thinking: bool = field(init=False, default=False)\n\n    def __post_init__(self) -> None:\n        super().__post_init__()\n        self._in_thinking = False\n\n    def _emit_deltas(self) -> Iterator[\"MessageDelta\"]:\n        \"\"\"Emit deltas with <think>/</think> tag awareness.\"\"\"\n        text = self._accumulated_text\n        pos = self._last_emitted_pos\n\n        while pos < len(text):\n            if not self._in_thinking:\n                # Look for <think> tag\n                think_start = text.find(_THINK_OPEN_TAG, pos)\n                if think_start == -1:\n                    # No <think> tag found - emit text up to a safe point.\n                    # Keep any trailing chars that could be the start of \"<think>\".\n                    suffix_from_pos = text[pos:]\n                    keep = _longest_matching_suffix_prefix(suffix_from_pos, _THINK_OPEN_TAG)\n                    safe_end = len(text) - keep\n                    if safe_end > pos:\n                        new_text = text[pos:safe_end]\n                        if new_text:\n                            yield StreamingTextDelta(\n                                text=new_text, content_index=self._content_index\n                            )\n                        self._last_emitted_pos = safe_end\n                    break\n                elif think_start > pos:\n                    # Emit text before <think>\n                    new_text = text[pos:think_start]\n                    if new_text:\n                        yield StreamingTextDelta(text=new_text, content_index=self._content_index)\n                    pos = think_start\n\n                if text[pos:].startswith(_THINK_OPEN_TAG):\n                    # Enter thinking mode\n                    self._in_thinking = True\n                    self._content_index += 1\n                    pos += len(_THINK_OPEN_TAG)\n                    self._last_emitted_pos = pos\n            else:\n                # In thinking mode - look for </think>\n                think_end = text.find(_THINK_CLOSE_TAG, pos)\n                if think_end == -1:\n                    # No </think> found - emit thinking up to safe point.\n                    # Keep any trailing chars that could be the start of \"</think>\".\n                    suffix_from_pos = text[pos:]\n                    keep = _longest_matching_suffix_prefix(suffix_from_pos, _THINK_CLOSE_TAG)\n                    safe_end = len(text) - keep\n                    if safe_end > pos:\n                        new_thinking = text[pos:safe_end]\n                        if new_thinking:\n                            yield StreamingThinkingDelta(\n                                thinking=new_thinking, content_index=self._content_index\n                            )\n                        self._last_emitted_pos = safe_end\n                    break\n                else:\n                    # Emit thinking before </think>\n                    new_thinking = text[pos:think_end]\n                    if new_thinking:\n                        yield StreamingThinkingDelta(\n                            thinking=new_thinking, content_index=self._content_index\n                        )\n       
             # Exit thinking mode\n                    self._in_thinking = False\n                    self._content_index += 1\n                    pos = think_end + len(_THINK_CLOSE_TAG)\n                    self._last_emitted_pos = pos\n\n    def _emit_remaining(self) -> Iterator[\"MessageDelta\"]:\n        \"\"\"Emit remaining content, respecting thinking state.\"\"\"\n        text = self._accumulated_text\n        pos = self._last_emitted_pos\n        if pos < len(text):\n            remaining = text[pos:]\n            if self._in_thinking:\n                if remaining:\n                    yield StreamingThinkingDelta(\n                        thinking=remaining, content_index=self._content_index\n                    )\n            else:\n                if remaining:\n                    yield StreamingTextDelta(text=remaining, content_index=self._content_index)\n\n    def reset(self) -> None:\n        \"\"\"Reset parser state for reuse.\"\"\"\n        super().reset()\n        self._in_thinking = False\n\n\n# NOTE: we use a broad type definition for the role to be flexible\n# Common roles are \"user\", \"assistant\", \"system\", \"tool\"\nRole = str\n\n# Content is a string or a list of parts\nContent = str | list[ContentPart]\n\n\nclass Message(TypedDict):\n    \"\"\"\n    Container for a single turn in a multi-turn conversation.\n\n    Args:\n\n    role: Role\n        String that denotes the source of the message, typically system, user, assistant, and tool.\n    content: Content\n        Content of the message, can be a string, or a list of ContentPart.\n        When content is a list, it can contain TextPart, ImagePart, and ThinkingPart elements.\n        ThinkingPart represents the model's internal reasoning (chain-of-thought).\n    tool_calls: NotRequired[list[ToolCall]]\n        Optional sequence of successfully parsed tool calls generated by the model.\n    unparsed_tool_calls: NotRequired[list[UnparsedToolCall]]\n        Optional sequence of tool calls that failed to parse (e.g., invalid JSON).\n        The raw text is preserved for debugging or re-rendering.\n    trainable: NotRequired[bool]\n        Optional indicator whether this message should contribute to the training loss.\n    tool_call_id: NotRequired[str]\n        For tool result messages (role=\"tool\"): ID correlating this result to a specific\n        tool call. Used by renderers whose wire format references calls by ID (e.g., Kimi K2\n        renders \"## Return of {tool_call_id}\"). The value should match ToolCall.id from the\n        assistant's tool_calls. 
Not all formats use IDs - GptOss/Harmony does not.\n    name: NotRequired[str]\n        For tool result messages (role=\"tool\"): The function name that was called.\n        Required by GptOss (renders \"<|start|>functions.{name}...\"), optional for others.\n        When constructing tool results, include both name and tool_call_id when available\n        since different renderers require different fields.\n\n    \"\"\"\n\n    role: Role\n    content: Content\n\n    tool_calls: NotRequired[list[ToolCall]]\n    unparsed_tool_calls: NotRequired[list[\"UnparsedToolCall\"]]\n    trainable: NotRequired[bool]\n    tool_call_id: NotRequired[str]\n    name: NotRequired[str]\n\n\n@dataclass\nclass RenderContext:\n    \"\"\"\n    Context passed to render_message for rendering a single message.\n\n    This allows renderers to access information about the message's position\n    in the conversation without changing the render_message signature for\n    each new piece of context needed.\n    \"\"\"\n\n    idx: int\n    \"\"\"Index of the message in the conversation (0-based).\"\"\"\n\n    is_last: bool\n    \"\"\"Whether this is the last message in the conversation.\"\"\"\n\n    prev_message: Message | None = None\n    \"\"\"The previous message in the conversation, if any.\"\"\"\n\n    last_user_index: int = -1\n    \"\"\"Index of the last user message in the conversation. -1 if no user messages.\n\n    This is computed by the base build_generation_prompt/build_supervised_example\n    and used by renderers like Qwen3.5 that need to treat assistant messages\n    differently based on whether they come before or after the last user message.\n    \"\"\"\n\n\nclass ToolSpec(TypedDict):\n    \"\"\"\n    Tool specification following the OpenAI function calling format.\n\n    This represents a tool that can be called by the model, including its name,\n    description, and parameter schema.\n\n    Example:\n        tool_spec: ToolSpec = {\n            \"name\": \"get_weather\",\n            \"description\": \"Get the current weather for a location\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\"type\": \"string\", \"description\": \"City name\"},\n                },\n                \"required\": [\"location\"],\n            },\n        }\n    \"\"\"\n\n    name: str\n    \"\"\"The name of the tool.\"\"\"\n    description: str\n    \"\"\"A description of what the tool does.\"\"\"\n    parameters: dict\n    \"\"\"JSON Schema object describing the tool's parameters.\"\"\"\n\n\ndef ensure_text(content: Content) -> str:\n    \"\"\"\n    Assert that content is text-only and return it as a string.\n\n    Raises ValueError if content contains images or multiple parts.\n    Use this to validate that message content is text-only before\n    processing it in code paths that don't support multimodal content.\n    \"\"\"\n    if isinstance(content, str):\n        return content\n    if len(content) == 1 and content[0][\"type\"] == \"text\":\n        return content[0][\"text\"]\n    raise RendererError(f\"Expected text content, got multimodal content with {len(content)} parts\")\n\n\ndef ensure_list(content: Content) -> list[ContentPart]:\n    \"\"\"Normalize content to list form. 
Wraps string content in a TextPart.\"\"\"\n    if isinstance(content, str):\n        return [TextPart(type=\"text\", text=content)]\n    return content\n\n\ndef content_to_jsonable(content: Content) -> str | list[dict[str, Any]]:\n    \"\"\"Convert message content to a JSON-serializable structure.\"\"\"\n    if isinstance(content, str):\n        return content\n\n    result: list[dict[str, Any]] = []\n    for part in content:\n        if part[\"type\"] == \"text\":\n            result.append({\"type\": \"text\", \"text\": part[\"text\"]})\n        elif part[\"type\"] == \"thinking\":\n            result.append({\"type\": \"thinking\", \"thinking\": part[\"thinking\"]})\n        elif part[\"type\"] == \"image\":\n            image: str | Image.Image = part[\"image\"]\n            image_part: dict[str, Any] = {\"type\": \"image\"}\n            if isinstance(image, str):\n                image_part[\"image\"] = image\n            result.append(image_part)\n        else:\n            raise RendererError(f\"Unknown content part type: {part['type']}\")\n    return result\n\n\ndef message_to_jsonable(message: Message) -> dict[str, Any]:\n    \"\"\"Convert a Message TypedDict to a JSON-serializable dict without losing metadata.\"\"\"\n    result: dict[str, Any] = {\n        \"role\": message[\"role\"],\n        \"content\": content_to_jsonable(message[\"content\"]),\n    }\n    if \"tool_calls\" in message:\n        result[\"tool_calls\"] = [tc.model_dump(mode=\"json\") for tc in message[\"tool_calls\"]]\n    if \"unparsed_tool_calls\" in message:\n        result[\"unparsed_tool_calls\"] = [\n            tc.model_dump(mode=\"json\") for tc in message[\"unparsed_tool_calls\"]\n        ]\n    if \"trainable\" in message:\n        result[\"trainable\"] = message[\"trainable\"]\n    if \"tool_call_id\" in message:\n        result[\"tool_call_id\"] = message[\"tool_call_id\"]\n    if \"name\" in message:\n        result[\"name\"] = message[\"name\"]\n    return result\n\n\ndef remove_thinking(parts: list[ContentPart]) -> list[ContentPart]:\n    \"\"\"Filter out ThinkingPart elements from a content part list.\"\"\"\n    return [p for p in parts if p[\"type\"] != \"thinking\"]\n\n\ndef get_text_content(message: Message) -> str:\n    \"\"\"Extract text content from message, stripping thinking parts.\n\n    Use this after parse_response when you only need the text output,\n    ignoring any thinking/reasoning content.\n    \"\"\"\n    content = message[\"content\"]\n    if isinstance(content, str):\n        return content\n    return \"\".join(p[\"text\"] for p in content if p[\"type\"] == \"text\")\n\n\ndef format_content_as_string(content: Content, separator: str = \"\\n\") -> str:\n    \"\"\"Format message content as a string, preserving all part types.\n\n    Unlike get_text_content which only extracts text parts, this formats\n    all content parts (thinking, text) as a readable string.\n\n    This is useful for compatibility with APIs that expect string content\n    (e.g., OpenAI Chat Completions API), but we don't recommend it if you\n    need to ensure correctness - prefer working with structured content directly\n    and using build_generation_prompt to convert to tokens.\n\n    Args:\n        content: Message content (string or list of ContentPart).\n        separator: String to join parts with. 
Default is newline.\n\n    Returns:\n        Formatted string representation of all content parts.\n    \"\"\"\n    if isinstance(content, str):\n        return content\n\n    parts = []\n    for p in content:\n        if p[\"type\"] == \"thinking\":\n            parts.append(f\"<think>{p['thinking']}</think>\")\n        elif p[\"type\"] == \"text\":\n            parts.append(p[\"text\"])\n        else:\n            raise RendererError(f\"Unknown content part type: {p['type']}\")\n    return separator.join(parts)\n\n\ndef _parse_tool_call_json(tool_call_str: str, raw_text: str) -> ToolCall | UnparsedToolCall:\n    \"\"\"Parse tool call JSON. Returns UnparsedToolCall on failure.\"\"\"\n    try:\n        tool_call = json.loads(tool_call_str.strip())\n    except json.JSONDecodeError as e:\n        return UnparsedToolCall(raw_text=raw_text, error=f\"Invalid JSON: {e}\")\n\n    if not isinstance(tool_call, dict):\n        return UnparsedToolCall(raw_text=raw_text, error=\"Tool call is not a JSON object\")\n\n    name = tool_call.get(\"name\")\n    arguments = tool_call.get(\"arguments\")\n    tool_id = tool_call.get(\"id\")\n\n    if not isinstance(name, str):\n        return UnparsedToolCall(raw_text=raw_text, error=\"Missing or invalid 'name' field\")\n    if not isinstance(arguments, dict):\n        return UnparsedToolCall(raw_text=raw_text, error=\"Missing or invalid 'arguments' field\")\n\n    if tool_id is not None and not isinstance(tool_id, str):\n        tool_id = None\n\n    # TODO: arguments is already a dict from json.loads above, but ToolCall.FunctionBody.arguments\n    # expects a JSON string. This round-trip (loads then dumps) is wasteful. Consider changing\n    # FunctionBody.arguments to accept dict directly, or parse tool calls more lazily.\n    # We may want to revisit the decision to store arguments as unparsed JSON string.\n    return ToolCall(\n        function=ToolCall.FunctionBody(name=name, arguments=json.dumps(arguments)),\n        id=tool_id,\n    )\n\n\ndef parse_content_blocks(\n    content: str,\n) -> tuple[list[ContentPart], list[ToolCall | UnparsedToolCall]] | None:\n    \"\"\"\n    Parse a string with <think>...</think> and <tool_call>...</tool_call> tags.\n\n    Handles interleaved thinking, tool call, and text blocks. 
Content parts\n    (ThinkingPart, TextPart) are returned in the first element; tool calls\n    (ToolCall, UnparsedToolCall) are returned separately in the second element,\n    preserving their relative order.\n\n    Whitespace in non-tool-call regions is preserved exactly - roundtrip\n    (parse then render) is identity for the content parts.\n\n    Args:\n        content: String potentially containing <think> and/or <tool_call> blocks.\n\n    Returns:\n        Tuple of (content_parts, tool_calls), or None if no special tags are found.\n        content_parts contains only ThinkingPart/TextPart.\n        tool_calls contains ToolCall and UnparsedToolCall in order.\n\n    Example:\n        >>> parse_content_blocks(\"<think>step 1</think>answer<tool_call>{...}</tool_call>more\")\n        (\n            [ThinkingPart(type=\"thinking\", thinking=\"step 1\"),\n             TextPart(type=\"text\", text=\"answer\"),\n             TextPart(type=\"text\", text=\"more\")],\n            [ToolCall(...)],\n        )\n    \"\"\"\n    if \"<think>\" not in content and \"<tool_call>\" not in content:\n        return None  # No special blocks, caller should use original string\n\n    parts: list[ContentPart] = []\n    tool_calls: list[ToolCall | UnparsedToolCall] = []\n    pos = 0\n\n    # Pattern to find both <think>...</think> and <tool_call>...</tool_call> blocks\n    pattern = re.compile(r\"<think>(.*?)</think>|<tool_call>(.*?)</tool_call>\", re.DOTALL)\n\n    for match in pattern.finditer(content):\n        # Add any text before this block (preserve whitespace for identity roundtrip)\n        text_before = content[pos : match.start()]\n        if text_before:  # Skip only truly empty strings\n            parts.append(TextPart(type=\"text\", text=text_before))\n\n        if match.group(1) is not None:\n            # This is a <think> block\n            thinking = match.group(1)\n            if thinking:  # Skip empty thinking blocks\n                parts.append(ThinkingPart(type=\"thinking\", thinking=thinking))\n        else:\n            # This is a <tool_call> block — goes into separate tool_calls list\n            tool_call_json = match.group(2)\n            raw_text = match.group(0)  # Full match including tags\n            tool_calls.append(_parse_tool_call_json(tool_call_json, raw_text))\n\n        pos = match.end()\n\n    # Add any remaining text after the last block\n    remaining = content[pos:]\n    if remaining:  # Skip only truly empty strings\n        parts.append(TextPart(type=\"text\", text=remaining))\n\n    return parts, tool_calls\n\n\ndef parse_think_blocks(content: str) -> list[ContentPart] | None:\n    \"\"\"\n    Parse a string with only <think>...</think> tags into ThinkingPart/TextPart list.\n\n    This is a simpler version of parse_content_blocks for renderers that use\n    non-standard tool call formats (like DeepSeek's <｜tool▁calls▁begin｜>).\n\n    Whitespace is preserved exactly - roundtrip (parse then render) is identity.\n\n    Args:\n        content: String potentially containing <think>...</think> blocks.\n\n    Returns:\n        List of ThinkingPart and TextPart in order. 
None if no <think> tags found.\n    \"\"\"\n    if \"<think>\" not in content:\n        return None\n\n    parts: list[ContentPart] = []\n    pos = 0\n    pattern = re.compile(r\"<think>(.*?)</think>\", re.DOTALL)\n\n    for match in pattern.finditer(content):\n        text_before = content[pos : match.start()]\n        if text_before:  # Skip only truly empty strings\n            parts.append(TextPart(type=\"text\", text=text_before))\n\n        thinking = match.group(1)\n        if thinking:  # Skip empty thinking blocks\n            parts.append(ThinkingPart(type=\"thinking\", thinking=thinking))\n\n        pos = match.end()\n\n    remaining = content[pos:]\n    if remaining:  # Skip only truly empty strings\n        parts.append(TextPart(type=\"text\", text=remaining))\n\n    return parts\n\n\ndef _tool_call_payload(tool_call: ToolCall) -> dict[str, object]:\n    \"\"\"Minimal JSON payload for embedding in <tool_call> blocks.\"\"\"\n    # Convert from nested structure to flat format for compatibility\n    return {\n        \"name\": tool_call.function.name,\n        \"arguments\": json.loads(tool_call.function.arguments),\n    }\n\n\n@dataclass(frozen=True)\nclass RenderedMessage:\n    \"\"\"\n    Container for parts of a rendered message, structured for loss masking.\n\n    A rendered message is split into header and output to control which tokens receive\n    training loss. In the simplest case (where the full conversation is formed by\n    concatenation), building a supervised example from messages [m_0, ..., m_{n-1}]\n    produces:\n\n        tokens = BOS + header_0 + output_0 + header_1 + output_1 + ... + header_{n-1} + output_{n-1}\n\n    However, some renderers modify this structure. For example, Qwen3Renderer strips\n    thinking blocks from historical assistant messages. Such renderers must override\n    build_supervised_example to match their build_generation_prompt behavior.\n\n    Attributes:\n        output: What the model generates for this turn: the message text/images plus\n            end-of-turn tokens. This is the trainable portion.\n            Examples: \" Hello world\\\\\\\\n\\\\\\\\n\" (RoleColon), \"Hello world<|eot_id|>\" (Llama3).\n        header: Role identifier and delimiters that introduce the turn. This is what the\n            model sees but does not generate.\n            Examples: \"User:\" (RoleColon), \"<|start_header_id|>user<|end_header_id|>\\\\\\\\n\\\\\\\\n\" (Llama3).\n            Typically receives zero training weight.\n        stop_overlap: Edge case field for formats where the stop sequence spans message\n            boundaries. Most renderers (Llama3, Qwen3, DeepSeek, etc.) don't use this—their\n            stop tokens are included in output.\n\n            Only RoleColonRenderer uses this. Its stop sequence is \"\\\\\\\\n\\\\\\\\nUser:\", where \"\\\\\\\\n\\\\\\\\n\"\n            ends the output but \"User:\" would duplicate the next message's header. To avoid\n            duplication, \"User:\" is stored here and only appended for the last message in\n            supervised training. 
The name \"stop_overlap\" reflects that these tokens are the\n            overlap between the stop sequence and the next message's header.\n    \"\"\"\n\n    output: list[tinker.ModelInputChunk]\n    \"\"\"What the model generates for this turn.\"\"\"\n\n    header: tinker.EncodedTextChunk | None = None\n    \"\"\"Role identifier and delimiters that introduce the turn.\"\"\"\n\n    stop_overlap: tinker.EncodedTextChunk | None = None\n    \"\"\"Tokens that overlap between stop sequence and next message's header.\"\"\"\n\n\nclass TrainOnWhat(StrEnum):\n    LAST_ASSISTANT_MESSAGE = \"last_assistant_message\"\n    LAST_ASSISTANT_TURN = \"last_assistant_turn\"\n    ALL_ASSISTANT_MESSAGES = \"all_assistant_messages\"\n    ALL_MESSAGES = \"all_messages\"\n    ALL_TOKENS = \"all_tokens\"\n    ALL_USER_AND_SYSTEM_MESSAGES = \"all_user_and_system_messages\"\n    CUSTOMIZED = \"customized\"\n\n\ndef _unpickle_renderer(\n    renderer_name: str, model_name: str, has_image_processor: bool\n) -> \"Renderer\":\n    \"\"\"Reconstruct a Renderer from its name and model name.\n\n    Called by pickle to deserialize Renderer instances. Uses cached tokenizer/image_processor\n    so reconstruction cost is negligible after first call.\n    \"\"\"\n    from tinker_cookbook.renderers import get_renderer\n    from tinker_cookbook.tokenizer_utils import get_tokenizer\n\n    tokenizer = get_tokenizer(model_name)\n    image_processor = None\n    if has_image_processor:\n        from tinker_cookbook.image_processing_utils import get_image_processor\n\n        image_processor = get_image_processor(model_name)\n    return get_renderer(renderer_name, tokenizer, image_processor, model_name=model_name)\n\n\nclass Renderer(ABC):\n    \"\"\"\n    Abstract base class for rendering message lists into training and sampling prompts.\n\n    Subclasses must implement:\n    - get_stop_sequences(): Return stop tokens/strings for sampling\n    - render_message(): Break a message into header/output/stop_overlap components\n    - parse_response(): Convert sampled tokens back into a Message\n\n    The default build_generation_prompt and build_supervised_example implementations\n    assume simple concatenation of rendered messages. Override these if your renderer\n    modifies the conversation structure (e.g., stripping thinking blocks from history).\n\n    Pickle support: Renderers created via ``get_renderer()`` are automatically pickleable.\n    On deserialization, the tokenizer and image processor are reconstructed from cached\n    loaders, so the cost is negligible. Renderers created directly (not via ``get_renderer()``)\n    must set ``_renderer_name`` and ``_model_name`` manually to be pickleable.\n\n    Implementations of ``EnvGroupBuilder`` must be pickleable to support distributed rollout\n    execution. 
Since many builders store a Renderer, this pickle support is critical.\n    \"\"\"\n\n    tokenizer: Tokenizer\n\n    # Pickle metadata — set by get_renderer() via _stamp_pickle_metadata().\n    # Class-level defaults ensure these exist even when subclasses bypass super().__init__().\n    _renderer_name: str | None = None\n    _model_name: str | None = None\n    _has_image_processor: bool = False\n\n    def __init__(self, tokenizer: Tokenizer):\n        self.tokenizer = tokenizer\n\n    def __reduce__(self) -> tuple:\n        \"\"\"Enable pickling by storing only (renderer_name, model_name, has_image_processor).\n\n        On unpickling, the Renderer is reconstructed via get_renderer() with a\n        cached tokenizer, so the cost is negligible.\n        \"\"\"\n        renderer_name = getattr(self, \"_renderer_name\", None)\n        model_name = getattr(self, \"_model_name\", None)\n        has_image_processor = getattr(self, \"_has_image_processor\", False)\n        if renderer_name is None or model_name is None:\n            raise pickle.PicklingError(\n                f\"Cannot pickle {type(self).__name__}: _renderer_name or _model_name not set. \"\n                \"Renderers must be created via get_renderer() to be pickleable, \"\n                \"or set _renderer_name and _model_name manually.\"\n            )\n        return (\n            _unpickle_renderer,\n            (renderer_name, model_name, has_image_processor),\n        )\n\n    @property\n    def has_extension_property(self) -> bool:\n        \"\"\"Whether this renderer satisfies the sequence extension property.\n\n        A renderer has the extension property if, for any multi-turn conversation,\n        calling build_generation_prompt at each successive assistant turn produces\n        token sequences where each is a prefix of the next. This enables:\n        - Merging multiple timesteps into a single training datum\n        - KV-cache reuse during sampling\n        - O(T) compute scaling instead of O(T^2) for T-turn trajectories\n\n        Renderers that strip thinking blocks from history (like Qwen3Renderer with\n        strip_thinking_from_history=True) do NOT have this property because the\n        observation at timestep 2 is not a prefix of timestep 1's full sequence.\n\n        See docs/rl/sequence-extension.mdx for details.\n        \"\"\"\n        return False\n\n    @property\n    def _bos_tokens(self) -> list[int]:\n        return []\n\n    @abstractmethod\n    def get_stop_sequences(self) -> list[str] | list[int]:\n        \"\"\"Return the stop sequences used when sampling from this renderer.\"\"\"\n        ...\n\n    @abstractmethod\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        \"\"\"\n        Render a single message into its header/output/stop_overlap components.\n\n        This method breaks down a message into parts for loss masking. 
See RenderedMessage\n        for detailed semantics of each component.\n\n        Args:\n            message: The message to render.\n            ctx: Context about the message's position in the conversation, including:\n                - idx: The index of this message (0-based)\n                - is_last: Whether this is the last message\n                - prev_message: The previous message, if any\n\n        Returns:\n            RenderedMessage with header, output, and optionally stop_overlap.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def parse_response(self, response: list[int]) -> tuple[Message, bool]:\n        \"\"\"\n        Parse sampled tokens back into a Message.\n\n        Args:\n            response: Token IDs returned from sampling.\n\n        Returns:\n            A tuple of (message, success). If success is False, the response could not\n            be parsed (e.g., missing stop token), but a best-effort message is still returned.\n        \"\"\"\n        ...\n\n    supports_streaming: bool = False\n    \"\"\"Whether this renderer supports streaming response parsing.\n\n    Renderers that set this to True get a default parse_response_streaming\n    implementation using ReasoningStreamingParser. They must also define\n    ``_end_message_token`` and ``_parse_response_for_streaming``.\n    \"\"\"\n\n    def _normalize_response_tokens(self, response: list[int]) -> list[int]:\n        \"\"\"Normalize sampled response tokens before parsing.\n\n        Subclasses that prefill tokens in build_generation_prompt (e.g. <think>)\n        should override this to restore the prefilled tokens so that parse_response\n        and parse_response_streaming see a complete token sequence.\n\n        The default implementation is the identity function.\n        \"\"\"\n        return response\n\n    @property\n    def _end_message_token(self) -> int:\n        \"\"\"The token ID that marks the end of a message.\n\n        Must be overridden by subclasses that set supports_streaming = True.\n        \"\"\"\n        raise NotImplementedError(\n            f\"{type(self).__name__} must define _end_message_token to support streaming\"\n        )\n\n    def _parse_response_for_streaming(self, response: list[int]) -> tuple[Message, bool]:\n        \"\"\"Parse response for streaming, always applying full content parsing.\n\n        Unlike parse_response which may short-circuit on missing stop token,\n        this always parses content blocks from the response. This ensures\n        the final Message emitted by streaming is complete even for truncated\n        responses.\n\n        The default delegates to parse_response. Subclasses should override\n        if their parse_response short-circuits on missing stop token.\n        \"\"\"\n        return self.parse_response(response)\n\n    def parse_response_streaming(self, response: list[int]) -> Iterator[MessageDelta]:\n        \"\"\"Parse response tokens with streaming, yielding incremental deltas.\n\n        This enables real-time display of model output by yielding partial\n        content as tokens arrive, rather than waiting for the complete response.\n\n        Renderers that set ``supports_streaming = True`` get a default\n        implementation using ReasoningStreamingParser. 
Others raise\n        NotImplementedError.\n\n        Args:\n            response: Token IDs from the model.\n\n        Yields:\n            StreamingMessageHeader: Once at the start of the message.\n            StreamingTextDelta: Incremental text content.\n            StreamingThinkingDelta: Incremental thinking/reasoning content.\n            Message: The complete parsed message at the end.\n        \"\"\"\n        if not self.supports_streaming:\n            raise NotImplementedError(\n                f\"{type(self).__name__} does not support streaming response parsing\"\n            )\n        response = self._normalize_response_tokens(response)\n        parser = ReasoningStreamingParser(\n            tokenizer=self.tokenizer,\n            end_message_token=self._end_message_token,\n            parse_final_response=self._parse_response_for_streaming,\n        )\n        for token in response:\n            yield from parser.feed(token)\n        yield from parser.finish()\n\n    def to_openai_message(self, message: Message) -> dict:\n        \"\"\"\n        Convert a Message to OpenAI chat completions API format.\n\n        The returned object can be passed into the transformers library's\n        apply_chat_template function, which is useful for testing purposes.\n\n        It's also useful for querying models that are being served through\n        OpenAI-compatible APIs (OpenRouter, vLLM, etc.).\n\n        The base implementation handles:\n        - Basic role/content conversion\n        - tool_calls conversion from ToolCall objects to OpenAI dict format\n        - tool_call_id and name for tool response messages\n\n        Subclasses should override this to handle model-specific features like\n        reasoning_content for thinking models.\n\n        Args:\n            message: The Message to convert.\n\n        Returns:\n            A dict in OpenAI API message format.\n        \"\"\"\n        result: dict = {\"role\": message[\"role\"]}\n\n        # Handle content\n        content = message[\"content\"]\n        if isinstance(content, str):\n            result[\"content\"] = content\n        else:\n            # Structured content with ThinkingPart/TextPart/etc.\n            # Base implementation: concatenate text parts, render thinking as <think> tags\n            # TODO: Add proper support for ImagePart by converting to OpenAI-style content parts\n            # (list of {\"type\": \"image_url\", \"image_url\": {...}} dicts)\n            parts = []\n            for p in content:\n                if p[\"type\"] == \"text\":\n                    parts.append(p[\"text\"])\n                elif p[\"type\"] == \"thinking\":\n                    parts.append(f\"<think>{p['thinking']}</think>\")\n                elif p[\"type\"] == \"image\":\n                    raise NotImplementedError(\n                        \"to_openai_message does not support ImagePart content. \"\n                        \"Images would be silently dropped, leading to incorrect HF template \"\n                        \"comparisons or OpenAI API calls. 
Use build_generation_prompt for VL models.\"\n                    )\n            result[\"content\"] = \"\".join(parts)\n\n        # Handle tool_calls (convert ToolCall objects to OpenAI format)\n        if \"tool_calls\" in message and message[\"tool_calls\"]:  # noqa: RUF019\n            result[\"tool_calls\"] = [\n                {\n                    \"type\": \"function\",\n                    \"id\": tc.id,\n                    \"function\": {\n                        \"name\": tc.function.name,\n                        \"arguments\": tc.function.arguments,\n                    },\n                }\n                for tc in message[\"tool_calls\"]\n            ]\n\n        # Handle tool response fields\n        if message[\"role\"] == \"tool\":\n            if \"tool_call_id\" in message:\n                result[\"tool_call_id\"] = message[\"tool_call_id\"]\n            if \"name\" in message:\n                result[\"name\"] = message[\"name\"]\n\n        return result\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> list[Message]:\n        \"\"\"Create message(s) with tool specifications to prepend to conversations.\n\n        Returns one or more messages to prepend to the conversation. This is the\n        standard way to add tools - the returned messages should be placed at the\n        start of your message list before user/assistant messages.\n\n        Args:\n            tools: List of tool specifications.\n            system_prompt: The system prompt content.\n\n        Returns:\n            List of messages to prepend to the conversation.\n\n        Raises:\n            NotImplementedError: If the renderer doesn't support tool calling.\n        \"\"\"\n        raise NotImplementedError\n\n    def _get_generation_suffix(self, role: Role, ctx: RenderContext) -> list[int]:\n        \"\"\"Return tokens to append to the prompt for generation.\n\n        This is called by build_generation_prompt to add the role header that\n        precedes the model's response. The default implementation renders an\n        empty message and extracts its header tokens.\n\n        Args:\n            role: The role to generate (usually \"assistant\")\n            ctx: Context for the generation suffix. Note that ctx.is_last is True\n                because we're rendering the header for the final (to-be-generated) message.\n\n        Returns:\n            List of token IDs for the role header. 
Examples in string form:\n            - Llama3: \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            - Qwen3: \"<|im_start|>assistant\\n\"\n            - DeepSeek: \"<｜Assistant｜>\" (single special token)\n        \"\"\"\n        # Default: render an empty message and use its header tokens\n        rendered = self.render_message(Message(role=role, content=\"\"), ctx)\n        if rendered.header:\n            return list(rendered.header.tokens)\n        return []\n\n    def build_generation_prompt(\n        self, messages: list[Message], role: Role = \"assistant\", prefill: str | None = None\n    ) -> tinker.ModelInput:\n        \"\"\"\n        Generates tokens for sampling from the model.\n\n        Args:\n            messages: a list of messages to render.\n            role: the role of the partial message to be completed.\n            prefill: an optional string to prefill in the model's generation.\n        \"\"\"\n\n        chunks: list[tinker.types.ModelInputChunk] = []\n        if self._bos_tokens:\n            chunks.append(tinker.types.EncodedTextChunk(tokens=self._bos_tokens))\n\n        last_user_idx = max(\n            (idx for idx, message in enumerate(messages) if message[\"role\"] == \"user\"),\n            default=-1,\n        )\n\n        for idx, message in enumerate(messages):\n            ctx = RenderContext(\n                idx=idx,\n                is_last=(idx == len(messages) - 1),\n                prev_message=messages[idx - 1] if idx > 0 else None,\n                last_user_index=last_user_idx,\n            )\n            rendered_message = self.render_message(message, ctx)\n            header_chunk = rendered_message.header\n            output_chunks = rendered_message.output\n            if header_chunk:\n                chunks.append(header_chunk)\n            # Filter out empty EncodedTextChunks, which cause 400 errors in model requests\n            chunks.extend(\n                [x for x in output_chunks if not isinstance(x, tinker.EncodedTextChunk) or x.tokens]\n            )\n\n        suffix_ctx = RenderContext(\n            idx=len(messages),\n            is_last=True,\n            prev_message=messages[-1] if messages else None,\n            last_user_index=last_user_idx,\n        )\n        suffix_tokens = self._get_generation_suffix(role, suffix_ctx)\n        if suffix_tokens:\n            chunks.append(tinker.types.EncodedTextChunk(tokens=suffix_tokens))\n\n        if prefill:\n            chunks.append(\n                tinker.types.EncodedTextChunk(\n                    tokens=self.tokenizer.encode(prefill, add_special_tokens=False)\n                )\n            )\n        return tinker.ModelInput(chunks=chunks)\n\n    def build_supervised_examples(\n        self,\n        messages: list[Message],\n        train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_TURN,\n    ) -> list[tuple[tinker.ModelInput, torch.Tensor]]:\n        \"\"\"\n        Build tokens and per-token weights for supervised fine-tuning.\n        This function returns a list of examples in the form of tuples, where each tuple contains a model input and a tensor of weights.\n        This is needed because some renderers do not satisfy the extension property, so we need to return a list of examples instead of a single example.\n\n        This default implementation concatenates rendered messages in order, which assumes the renderer satisfies the extension property.\n        Override this method if your renderer does not satisfy the extension property.\n        
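Example (illustrative only; assumes a renderer obtained via ``get_renderer()`` and an existing ``messages`` list):\n\n            examples = renderer.build_supervised_examples(\n                messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_TURN\n            )\n            # Each (model_input, weights) pair is one independent training datum;\n            # renderers without the extension property may return several.\n        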
\"\"\"\n\n        if self.has_extension_property:\n            return [self.build_supervised_example(messages, train_on_what=train_on_what)]\n        else:\n            # TODO: Add a default implementation that calls `build_supervised_example` for each message and merges examples with shared prefixes.\n            raise NotImplementedError(\n                \"build_supervised_examples has not been implemented for this renderer.\"\n            )\n\n    def build_supervised_example(\n        self,\n        messages: list[Message],\n        train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE,\n    ) -> tuple[tinker.ModelInput, torch.Tensor]:\n        \"\"\"\n        Build tokens and per-token weights for supervised fine-tuning.\n\n        This default implementation concatenates rendered messages in order. Override\n        this method if your build_generation_prompt does anything that breaks the simple\n        concatenation assumption—for example, if it strips thinking blocks from history\n        (like Qwen3Renderer), injects default system prompts (like KimiK2Renderer), or\n        otherwise modifies the token sequence.\n\n        The supervised example tokens should match what build_generation_prompt would\n        produce for the same conversation prefix, so the model trains on the same\n        distribution it sees at inference time.\n\n        Args:\n            messages: A list of messages to render.\n            train_on_what: Controls which tokens receive non-zero training weight:\n                - LAST_ASSISTANT_MESSAGE: Only the last assistant message\n                - LAST_ASSISTANT_TURN: The last assistant message after the last user message\n                - ALL_ASSISTANT_MESSAGES: All assistant messages\n                - ALL_MESSAGES: All messages (but not headers)\n                - ALL_TOKENS: Everything including headers\n                - ALL_USER_AND_SYSTEM_MESSAGES: User and system messages only\n                - CUSTOMIZED: Use the 'trainable' field on each message\n\n        Returns:\n            A tuple of (model_input, weights) where weights is a 1D tensor with the\n            same length as the total number of tokens.\n        \"\"\"\n        # Warn if training on multiple assistant messages with a renderer that doesn't\n        # satisfy the extension property. In that case, each assistant message sees a\n        # different context prefix, so they should be trained as separate examples.\n        # NOTE: This warning only covers ALL_ASSISTANT_MESSAGES. Other modes that train\n        # multiple assistant messages (e.g., ALL_MESSAGES, ALL_TOKENS, CUSTOMIZED) should\n        # be used with caution when has_extension_property=False.\n        if train_on_what == TrainOnWhat.ALL_ASSISTANT_MESSAGES and not self.has_extension_property:\n            logger.warning(\n                \"WARNING: Using train_on_what=ALL_ASSISTANT_MESSAGES with a renderer that \"\n                \"does not satisfy the extension property (has_extension_property=False). \"\n                \"This means earlier assistant messages in the conversation see a different \"\n                \"token prefix than what build_generation_prompt would produce at that turn. \"\n                \"You should instead create separate conversations for each assistant message \"\n                \"and call build_supervised_example with train_on_what=LAST_ASSISTANT_MESSAGE \"\n                \"for each one. 
See docs/rl/sequence-extension.mdx for details.\"\n            )\n\n        model_input_chunks_weights: list[tuple[tinker.types.ModelInputChunk, float]] = []\n        if self._bos_tokens:\n            model_input_chunks_weights.append(\n                (tinker.types.EncodedTextChunk(tokens=self._bos_tokens), 0.0)\n            )\n\n        last_user_idx = max(\n            (idx for idx, message in enumerate(messages) if message[\"role\"] == \"user\"),\n            default=-1,\n        )\n\n        for idx, message in enumerate(messages):\n            if train_on_what == TrainOnWhat.CUSTOMIZED:\n                assert \"trainable\" in message, (\n                    \"When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise\"\n                )\n            else:\n                assert \"trainable\" not in message, (\n                    \"When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. Either change train_on_what to CUSTOMIZED or remove the trainable field from the message\"\n                )\n\n            is_last_message = idx == len(messages) - 1\n            is_assistant = message[\"role\"] == \"assistant\"\n            is_user_or_system = message[\"role\"] in [\"user\", \"system\"]\n            is_after_last_user = last_user_idx == -1 or idx > last_user_idx\n\n            # only apply weight to header if train_on_what is ALL_TOKENS\n            ctx = RenderContext(\n                idx=idx,\n                is_last=is_last_message,\n                prev_message=messages[idx - 1] if idx > 0 else None,\n                last_user_index=last_user_idx,\n            )\n            rendered_message = self.render_message(message, ctx)\n            header_part = rendered_message.header\n            output_parts = rendered_message.output\n            stop_overlap_part = rendered_message.stop_overlap\n\n            header_weight = int(train_on_what == TrainOnWhat.ALL_TOKENS)\n            if header_part:\n                model_input_chunks_weights += [(header_part, header_weight)]\n\n            match train_on_what:\n                case TrainOnWhat.LAST_ASSISTANT_MESSAGE:\n                    output_has_weight = is_last_message and is_assistant\n                case TrainOnWhat.LAST_ASSISTANT_TURN:\n                    output_has_weight = is_assistant and is_after_last_user\n                case TrainOnWhat.ALL_ASSISTANT_MESSAGES:\n                    output_has_weight = is_assistant\n                case TrainOnWhat.ALL_MESSAGES:\n                    output_has_weight = True\n                case TrainOnWhat.ALL_TOKENS:\n                    output_has_weight = True\n                case TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES:\n                    output_has_weight = is_user_or_system\n                case TrainOnWhat.CUSTOMIZED:\n                    output_has_weight = message.get(\"trainable\", False)\n                case _:\n                    raise RendererError(f\"Unknown train_on_what: {train_on_what}\")\n\n            model_input_chunks_weights += [\n                (output_part, int(output_has_weight)) for output_part in output_parts if output_part\n            ]\n\n            # stop_overlap completes the stop sequence for formats like RoleColon (e.g., \"User:\")\n            # Only included for the last message.\n            if is_last_message and stop_overlap_part:\n                model_input_chunks_weights += [(stop_overlap_part, int(output_has_weight))]\n\n     
   weights_data = [w for chunk, w in model_input_chunks_weights for _ in range(chunk.length)]\n        weights_tensor = torch.tensor(weights_data)\n\n        model_input_chunks = [chunk for chunk, _ in model_input_chunks_weights]\n        return tinker.ModelInput(chunks=model_input_chunks), weights_tensor\n\n\ndef tokens_weights_from_strings_weights(\n    strings_weights: list[tuple[str, float]],\n    tokenizer: Tokenizer,\n) -> tuple[torch.Tensor, torch.Tensor]:\n    strings, weights = zip(*strings_weights, strict=True)\n    token_chunks = [tokenizer.encode(s, add_special_tokens=i == 0) for i, s in enumerate(strings)]\n    weights = torch.cat(\n        [torch.full((len(chunk),), w) for chunk, w in zip(token_chunks, weights, strict=True)]\n    )\n    tokens = torch.cat([torch.tensor(chunk) for chunk in token_chunks])\n    assert tokens.dtype == torch.int64\n    return tokens, weights\n\n\ndef parse_response_for_stop_token(\n    response: list[int], tokenizer: Tokenizer, stop_token: int\n) -> tuple[Message, bool]:\n    \"\"\"Parse response for a single stop token.\n\n    We expect a properly rendered response to have exactly one stop token; but it may have zero if e.g. the model\n    ran out of tokens when sampling, which will incur a format error. If there are > 1, there is likely a bug in the\n    sampler and we should error.\n    \"\"\"\n    emt_count = response.count(stop_token)\n    if emt_count == 0:\n        str_response = str(tokenizer.decode(response))\n        logger.debug(f\"Response is not a valid assistant response: {str_response}\")\n        return Message(role=\"assistant\", content=str_response), False\n    elif emt_count == 1:\n        str_response = str(tokenizer.decode(response[: response.index(stop_token)]))\n        return Message(role=\"assistant\", content=str_response), True\n    else:\n        raise RendererError(\n            f\"When parsing response, expected to split into 1 or 2 pieces using stop tokens, but got {emt_count}. 
\"\n            \"You probably are using the wrong stop tokens when sampling\"\n        )\n\n\n# Image processing utilities (used by VL renderers)\n\n\nclass ImageProcessorProtocol(Protocol):\n    merge_size: int\n    patch_size: int\n\n    def get_number_of_image_patches(\n        self, height: int, width: int, images_kwargs: dict | None = None\n    ) -> int:\n        raise NotImplementedError()\n\n    def get_resize_config(self, image_data: dict[str, Any]) -> dict[str, Any]:\n        raise NotImplementedError()\n\n\ndef image_to_chunk(\n    image_or_str: Image.Image | str, image_processor: ImageProcessorProtocol\n) -> tinker.types.ImageChunk:\n    \"\"\"\n    Convert a PIL Image to a tinker.types.ImageChunk for QwenVL\n    \"\"\"\n\n    # load an image from a data URI or a URL\n    if isinstance(image_or_str, str):\n        with urllib.request.urlopen(image_or_str) as response:\n            pil_image = Image.open(io.BytesIO(response.read()))\n\n    # Otherwise the image is a PIL image and can be loaded directly\n    elif isinstance(image_or_str, Image.Image):\n        pil_image = image_or_str\n\n    # Validate the provided data is actually a valid image type\n    else:\n        raise RendererError(\"The provided image must be a PIL.Image.Image, URL, or data URI.\")\n\n    # Convert to RGB if needed (JPEG doesn't support RGBA/LA/P modes)\n    if pil_image.mode in (\"RGBA\", \"LA\", \"P\"):\n        pil_image = pil_image.convert(\"RGB\")\n\n    img_byte_arr = io.BytesIO()\n    pil_image.save(img_byte_arr, format=\"JPEG\")\n    image_data = img_byte_arr.getvalue()\n\n    # Get the number of expected tokens for the image. The way to do this is not consistent between\n    # image processors (qwen3vl supports get_number_of_image_patches, kimi2.5 doesn't but has get_resize_config)\n    if hasattr(image_processor, \"get_number_of_image_patches\"):\n        width, height = pil_image.size\n        num_image_tokens = (\n            image_processor.get_number_of_image_patches(height, width, images_kwargs={})\n            // image_processor.merge_size**2\n        )\n    elif hasattr(image_processor, \"get_resize_config\"):\n        config = image_processor.get_resize_config({\"type\": \"image\", \"image\": pil_image})\n        num_image_tokens = config[\"num_tokens\"]\n    else:\n        raise RendererError(\n            f\"Don't know how to get the number of image tokens for image processor: {image_processor}\"\n        )\n\n    return tinker.types.ImageChunk(\n        data=image_data,\n        format=\"jpeg\",\n        expected_tokens=num_image_tokens,\n    )\n"
  },
  {
    "path": "tinker_cookbook/renderers/deepseek_v3.py",
    "content": "\"\"\"\nDeepSeek V3 family renderers.\n\nIncludes:\n- DeepSeekV3ThinkingRenderer: V3 models in thinking mode\n- DeepSeekV3DisableThinkingRenderer: V3 models with thinking disabled\n\"\"\"\n\nimport json\nimport re\nimport warnings\n\nimport tinker\nimport transformers\n\nfrom tinker_cookbook.exceptions import RendererError\nfrom tinker_cookbook.renderers.base import (\n    Message,\n    RenderContext,\n    RenderedMessage,\n    Renderer,\n    ToolCall,\n    ToolSpec,\n    UnparsedToolCall,\n    ensure_text,\n    parse_response_for_stop_token,\n    parse_think_blocks,\n)\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n\nclass _DeepSeekV3BaseRenderer(Renderer):\n    \"\"\"\n    Base renderer for DeepSeek V3 models with common rendering logic.\n\n    This is a private base class. Use DeepSeekV3ThinkingRenderer or\n    DeepSeekV3DisableThinkingRenderer instead.\n\n    System messages at position 0 are rendered without role tokens (matching HF template).\n    System messages at later positions require system_role_as_user=True to convert to user role.\n\n    The default strip_thinking_from_history=True matches HF behavior where thinking\n    traces are removed from historical assistant messages in multi-turn conversations.\n    Use strip_thinking_from_history=False for multi-turn RL to get the extension property.\n    \"\"\"\n\n    supports_streaming = True\n\n    def __init__(\n        self,\n        tokenizer: Tokenizer,\n        system_role_as_user: bool = False,\n        strip_thinking_from_history: bool = True,\n    ):\n        super().__init__(tokenizer)\n        self.system_role_as_user = system_role_as_user\n        self.strip_thinking_from_history = strip_thinking_from_history\n\n        if transformers.__version__ == \"5.3.0\":\n            warnings.warn(\n                \"transformers 5.3.0 has a known bug with the DeepSeek tokenizer that \"\n                \"strips spaces during decode, which will produce incorrect outputs. \"\n                \"Please upgrade to transformers>=5.3.1 or downgrade to transformers<5.3.0. 
\"\n                \"See https://github.com/huggingface/transformers/pull/44801\",\n                stacklevel=2,\n            )\n\n    @property\n    def has_extension_property(self) -> bool:\n        \"\"\"Extension property depends on strip_thinking_from_history setting.\n\n        When strip_thinking_from_history=False, thinking traces are preserved in\n        history, so each successive observation is a prefix extension of the previous.\n\n        When strip_thinking_from_history=True (default), thinking traces are stripped\n        from historical messages, breaking the extension property.\n        \"\"\"\n        return not self.strip_thinking_from_history\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        \"\"\"Render a single message to tokens.\n\n        Args:\n            message: The message to render.\n            ctx: Context about the message's position, including:\n                - idx: The index of this message (0-based)\n                - is_last: Whether this is the last message (affects thinking stripping)\n                - prev_message: The previous message, used to detect post-tool formatting\n        \"\"\"\n        # Check if this assistant message follows a tool response\n        follows_tool = ctx.prev_message is not None and ctx.prev_message[\"role\"] == \"tool\"\n\n        content = message[\"content\"]\n\n        if message[\"role\"] == \"system\":\n            # HF template collects all system messages at the start without role tokens\n            # We only support this for idx=0; later system messages need system_role_as_user=True\n            content_str = ensure_text(content)\n            if ctx.idx == 0:\n                header_tokens: list[int] = []\n                output_str = content_str\n            elif self.system_role_as_user:\n                # Convert later system messages to user role\n                role_token = self._get_special_token(\"User\")\n                header_tokens = [role_token]\n                output_str = content_str\n            else:\n                raise RendererError(\n                    \"DeepSeek only supports system message at start. 
\"\n                    \"Use system_role_as_user=True to convert later system messages to user role.\"\n                )\n        elif message[\"role\"] == \"user\":\n            role_token = self._get_special_token(\"User\")\n            header_tokens = [role_token]\n            output_str = ensure_text(content)\n        elif message[\"role\"] == \"assistant\":\n            has_tool_calls = \"tool_calls\" in message and message[\"tool_calls\"]\n\n            # Determine if we should strip thinking content from this message\n            should_strip_thinking = (\n                self.strip_thinking_from_history and not has_tool_calls and not ctx.is_last\n            )\n\n            if isinstance(content, list):\n                # Structured content - handle with list operations\n                parts = content\n                # Render parts in order, preserving interleaved thinking/text structure.\n                # No separator needed - whitespace is preserved in TextPart for roundtrip identity.\n                rendered_parts = []\n                for p in parts:\n                    if p[\"type\"] == \"thinking\":\n                        if should_strip_thinking:\n                            # Skip thinking content entirely when stripping\n                            # (header gets </think> added separately to match HF format)\n                            pass\n                        else:\n                            rendered_parts.append(f\"<think>{p['thinking']}</think>\")\n                    elif p[\"type\"] == \"text\":\n                        rendered_parts.append(p[\"text\"])\n                output_content = \"\".join(rendered_parts)\n            else:\n                # String content - pass through as-is.\n                # Stripping only works with structured content (ThinkingPart).\n                output_content = content\n\n            if follows_tool:\n                # Post-tool assistant: no role token, content flows directly after tool output\n                header_tokens = []\n                output_str = output_content\n            else:\n                # Normal assistant message\n                role_token = self._get_special_token(\"Assistant\")\n                header_tokens = [role_token]\n                output_str = output_content\n        elif message[\"role\"] == \"tool\":\n            # Tool responses use special tool output tokens to match HF template\n            header_tokens = self.tokenizer.encode(\n                \"<｜tool▁output▁begin｜>\", add_special_tokens=False\n            )\n            output_str = ensure_text(content) + \"<｜tool▁output▁end｜>\"\n        else:\n            raise RendererError(f\"Unsupported role: {message['role']}\")\n\n        # Handle tool calls in assistant messages\n        # HF format: <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>name<｜tool▁sep｜>args<｜tool▁call▁end｜><｜tool▁calls▁end｜>\n        if \"tool_calls\" in message and message[\"tool_calls\"]:  # noqa: RUF019\n            output_str += \"<｜tool▁calls▁begin｜>\"\n            for tool_call in message[\"tool_calls\"]:\n                func_name = tool_call.function.name\n                args = tool_call.function.arguments\n                output_str += (\n                    f\"<｜tool▁call▁begin｜>{func_name}<｜tool▁sep｜>{args}<｜tool▁call▁end｜>\"\n                )\n            output_str += \"<｜tool▁calls▁end｜>\"\n\n        output_tokens = self.tokenizer.encode(output_str, add_special_tokens=False)\n\n        # Add end_of_sentence only for assistant messages with 
content\n        # (not for empty generation prompt messages)\n        if message[\"role\"] == \"assistant\" and message[\"content\"]:\n            output_tokens.append(self._end_message_token)\n\n        output: list[tinker.ModelInputChunk] = [tinker.types.EncodedTextChunk(tokens=output_tokens)]\n        # Only include header if non-empty; tinker rejects empty token chunks with\n        # \"Chunk N has empty tokens list\". This happens for system messages at idx=0.\n        if header_tokens:\n            return RenderedMessage(\n                header=tinker.types.EncodedTextChunk(tokens=header_tokens), output=output\n            )\n        else:\n            return RenderedMessage(output=output)\n\n    def _get_special_token(self, name: str) -> int:\n        sep = chr(65372)\n        s = f\"<{sep}{name}{sep}>\"\n        res = self.tokenizer.encode(s, add_special_tokens=False)\n        assert len(res) == 1, f\"Expected single token for {s}, got {res}\"\n        return res[0]\n\n    @property\n    def _bos_tokens(self) -> list[int]:\n        return [self._get_special_token(\"begin▁of▁sentence\")]\n\n    @property\n    def _end_message_token(self) -> int:\n        return self._get_special_token(\"end▁of▁sentence\")\n\n    def get_stop_sequences(self) -> list[int]:\n        return [self._end_message_token]\n\n    def _parse_deepseek_tool_calls(\n        self, content: str\n    ) -> tuple[list[ToolCall], list[UnparsedToolCall]]:\n        \"\"\"Parse tool calls from DeepSeek V3.1 format.\n\n        Expected format (per HuggingFace model card and chat template):\n            <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>func_name<｜tool▁sep｜>{\"arg\":\"value\"}<｜tool▁call▁end｜><｜tool▁calls▁end｜>\n\n        Multiple tool calls are chained directly without separators.\n\n        References:\n            - DeepSeek V3.1 Model Card: https://huggingface.co/deepseek-ai/DeepSeek-V3.1\n            - Chat Template: https://huggingface.co/deepseek-ai/DeepSeek-V3.1/blob/main/assets/chat_template.jinja\n        \"\"\"\n        tool_calls: list[ToolCall] = []\n        unparsed_tool_calls: list[UnparsedToolCall] = []\n\n        calls_match = re.search(\n            r\"<｜tool▁calls▁begin｜>(.*?)<｜tool▁calls▁end｜>\", content, re.DOTALL\n        )\n        if not calls_match:\n            return tool_calls, unparsed_tool_calls\n\n        for match in re.finditer(\n            r\"<｜tool▁call▁begin｜>(\\w+)<｜tool▁sep｜>(.*?)<｜tool▁call▁end｜>\",\n            calls_match.group(1),\n            re.DOTALL,\n        ):\n            raw_text = match.group(0)\n            func_name, args_str = match.group(1), match.group(2).strip()\n\n            try:\n                json.loads(args_str)\n                tool_calls.append(\n                    ToolCall(function=ToolCall.FunctionBody(name=func_name, arguments=args_str))\n                )\n            except json.JSONDecodeError as e:\n                unparsed_tool_calls.append(\n                    UnparsedToolCall(raw_text=raw_text, error=f\"Invalid JSON: {e}\")\n                )\n\n        return tool_calls, unparsed_tool_calls\n\n    def _parse_response_content(\n        self, response: list[int], *, allow_missing_stop: bool = False\n    ) -> tuple[Message, bool]:\n        \"\"\"Shared parsing logic for both batch and streaming paths.\n\n        Callers are responsible for normalization — this method does NOT call\n        ``_normalize_response_tokens``.\n        \"\"\"\n        assistant_message, parse_success = parse_response_for_stop_token(\n            response, 
self.tokenizer, self._end_message_token\n        )\n        if not parse_success and not allow_missing_stop:\n            return assistant_message, False\n\n        assert isinstance(assistant_message[\"content\"], str)\n        content = assistant_message[\"content\"]\n\n        # Parse DeepSeek-specific tool calls\n        tool_calls, unparsed_tool_calls = self._parse_deepseek_tool_calls(content)\n        if tool_calls:\n            assistant_message[\"tool_calls\"] = tool_calls\n        if unparsed_tool_calls:\n            assistant_message[\"unparsed_tool_calls\"] = unparsed_tool_calls\n\n        # Strip tool calls section from content (both parsed and unparsed)\n        if tool_calls or unparsed_tool_calls:\n            content = re.sub(\n                r\"\\s*<｜tool▁calls▁begin｜>.*?<｜tool▁calls▁end｜>\",\n                \"\",\n                content,\n                flags=re.DOTALL,\n            )\n            content = content.strip()\n\n        # Parse <think>...</think> blocks into ThinkingPart/TextPart list\n        parts = parse_think_blocks(content)\n        if parts is not None:\n            assistant_message[\"content\"] = parts\n        else:\n            assistant_message[\"content\"] = content\n\n        return assistant_message, parse_success\n\n    def parse_response(self, response: list[int]) -> tuple[Message, bool]:\n        response = self._normalize_response_tokens(response)\n        return self._parse_response_content(response, allow_missing_stop=False)\n\n    def _parse_response_for_streaming(self, response: list[int]) -> tuple[Message, bool]:\n        \"\"\"Parse response for streaming, always applying full content parsing.\n\n        Unlike parse_response which short-circuits on missing stop token,\n        this always parses think blocks and tool calls from the content.\n\n        Note: _normalize_response_tokens is NOT called here because\n        parse_response_streaming already normalizes before feeding tokens\n        to the parser.\n        \"\"\"\n        return self._parse_response_content(response, allow_missing_stop=True)\n\n    def to_openai_message(self, message: Message) -> dict:\n        \"\"\"Convert a Message to OpenAI API format with reasoning_content for thinking.\n\n        DeepSeek's API uses reasoning_content for thinking models, similar to OpenAI's o1.\n        \"\"\"\n        result: dict = {\"role\": message[\"role\"]}\n\n        content = message[\"content\"]\n        if isinstance(content, str):\n            result[\"content\"] = content\n        else:\n            # Extract thinking into reasoning_content, keep text in content\n            thinking_parts = []\n            text_parts = []\n            for p in content:\n                if p[\"type\"] == \"thinking\":\n                    thinking_parts.append(p[\"thinking\"])\n                elif p[\"type\"] == \"text\":\n                    text_parts.append(p[\"text\"])\n\n            result[\"content\"] = \"\".join(text_parts)\n            if thinking_parts:\n                result[\"reasoning_content\"] = \"\".join(thinking_parts)\n\n        # Handle tool_calls\n        if \"tool_calls\" in message and message[\"tool_calls\"]:  # noqa: RUF019\n            result[\"tool_calls\"] = [\n                {\n                    \"type\": \"function\",\n                    \"id\": tc.id,\n                    \"function\": {\n                        \"name\": tc.function.name,\n                        \"arguments\": tc.function.arguments,\n                    },\n                }\n       
         for tc in message[\"tool_calls\"]\n            ]\n\n        # Handle tool response fields\n        if message[\"role\"] == \"tool\":\n            if \"tool_call_id\" in message:\n                result[\"tool_call_id\"] = message[\"tool_call_id\"]\n            if \"name\" in message:\n                result[\"name\"] = message[\"name\"]\n\n        return result\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> list[Message]:\n        \"\"\"Create system message with DeepSeek V3.1 tool specifications.\n\n        DeepSeek V3.1 tool calling requires tools to be described in the system message\n        using a specific format with ### headers and inline JSON parameters.\n\n        Note: Tool calling is supported in non-thinking mode only.\n\n        References:\n            - DeepSeek V3.1 Model Card (ToolCall section): https://huggingface.co/deepseek-ai/DeepSeek-V3.1\n            - DeepSeek V3.1 Chat Template: https://huggingface.co/deepseek-ai/DeepSeek-V3.1/blob/main/assets/chat_template.jinja\n            - DeepSeek API Tool Calls Guide: https://api-docs.deepseek.com/guides/tool_calls\n        \"\"\"\n        tools_text = \"\"\n        if tools:\n            # Format each tool with ### header, description, and parameters\n            tool_blocks = []\n            for tool in tools:\n                tool_block = f\"\"\"### {tool[\"name\"]}\nDescription: {tool[\"description\"]}\n\nParameters: {json.dumps(tool[\"parameters\"])}\"\"\"\n                tool_blocks.append(tool_block)\n\n            tools_text = f\"\"\"\n\n## Tools\nYou have access to the following tools:\n\n{chr(10).join(tool_blocks)}\n\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>tool_call_name<｜tool▁sep｜>tool_call_arguments<｜tool▁call▁end｜><｜tool▁calls▁end｜>\n\nWhere:\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\"\"\"\n\n        return [Message(role=\"system\", content=system_prompt + tools_text)]\n\n\nclass DeepSeekV3ThinkingRenderer(_DeepSeekV3BaseRenderer):\n    \"\"\"\n    Renderer for DeepSeek V3 models in THINKING mode.\n\n    Format:\n        <|begin_of_sentence|><|User|>question<|Assistant|><think>reasoning</think>answer<|end_of_sentence|>\n\n    For non-thinking mode, use DeepSeekV3DisableThinkingRenderer instead.\n\n    Generation prompts include <think> prefill to trigger thinking mode.\n    Think tags in message content come from ThinkPart rendering.\n\n    When strip_thinking_from_history=True (default), historical assistant messages\n    get </think> added to header and thinking content stripped, matching HF behavior.\n    \"\"\"\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        \"\"\"Render message, adding </think> to header when stripping thinking from history.\n\n        HF's thinking=True template uses </think> at the start of historical assistant\n        messages to signal \"we're past the thinking phase, here's the answer\".\n        \"\"\"\n        rendered = super().render_message(message, ctx)\n\n        # Add </think> to header for historical assistant messages when stripping thinking.\n        # This matches the base class's should_strip_thinking logic - only historical messages\n        # (not the last one) get </think> 
added. The last message is the supervised target and\n        # should preserve its format (including any ThinkingPart).\n        follows_tool = ctx.prev_message is not None and ctx.prev_message[\"role\"] == \"tool\"\n        should_add_think_close = (\n            message[\"role\"] == \"assistant\"\n            and not follows_tool\n            and self.strip_thinking_from_history\n            and not ctx.is_last\n        )\n\n        if should_add_think_close:\n            think_close_tokens = self.tokenizer.encode(\"</think>\", add_special_tokens=False)\n            old_header_tokens = list(rendered.header.tokens) if rendered.header else []\n            new_header = tinker.EncodedTextChunk(tokens=old_header_tokens + think_close_tokens)\n            rendered = RenderedMessage(header=new_header, output=rendered.output)\n\n        return rendered\n\n    def build_generation_prompt(\n        self,\n        messages: list[Message],\n        role: str = \"assistant\",\n        prefill: str | None = None,\n    ) -> tinker.ModelInput:\n        \"\"\"Build generation prompt with <think> prefill to trigger thinking mode.\n\n        Does NOT add <think> when the previous message is a tool response,\n        as tool-use conversations stay in non-thinking mode (matching HF behavior).\n        \"\"\"\n        # Don't add <think> prefill after tool responses - tool use is non-thinking mode\n        if messages and messages[-1][\"role\"] == \"tool\":\n            return super().build_generation_prompt(messages, role, prefill)\n\n        # Add <think> prefill to trigger thinking, combined with any user-provided prefill\n        think_prefill = \"<think>\" + (prefill or \"\")\n        return super().build_generation_prompt(messages, role, think_prefill)\n\n    def _normalize_response_tokens(self, response: list[int]) -> list[int]:\n        \"\"\"Restore the prefilled <think> token before parsing sampled tokens.\n\n        When sampling with build_generation_prompt, the <think> tag is part of the\n        prefill and not included in the sampled tokens. 
The response will be\n        \"reasoning</think>answer\" so we prepend <think> if necessary.\n        \"\"\"\n        think_prefix_token: int = self.tokenizer.convert_tokens_to_ids(\"<think>\")  # type: ignore[assignment]\n        think_suffix_token: int = self.tokenizer.convert_tokens_to_ids(\"</think>\")  # type: ignore[assignment]\n\n        starts_with_think = len(response) > 0 and response[0] == think_prefix_token\n        if not starts_with_think and think_suffix_token in response:\n            return [think_prefix_token] + response\n        return response\n\n\nclass DeepSeekV3DisableThinkingRenderer(_DeepSeekV3BaseRenderer):\n    \"\"\"\n    Renderer for DeepSeek V3 models in NON-THINKING mode.\n\n    Format:\n        <|begin_of_sentence|><|User|>question<|Assistant|></think>answer<|end_of_sentence|>\n\n    The </think> prefix signals to the model to skip reasoning and respond directly.\n    Any <think>...</think> blocks in the content are stripped.\n\n    For thinking mode, use DeepSeekV3ThinkingRenderer instead.\n    \"\"\"\n\n    @property\n    def has_extension_property(self) -> bool:\n        \"\"\"Non-thinking mode always satisfies extension - no thinking to strip from history.\"\"\"\n        return True\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        \"\"\"Render message in non-thinking mode.\n\n        For assistant messages (not following tool):\n        - Strip any ThinkingPart from structured content\n        - Add </think> to header to signal non-thinking mode\n        \"\"\"\n        # Check if this assistant message follows a tool response\n        follows_tool = ctx.prev_message is not None and ctx.prev_message[\"role\"] == \"tool\"\n\n        if message[\"role\"] == \"assistant\" and not follows_tool:\n            content = message[\"content\"]\n\n            # Strip thinking from content\n            if isinstance(content, list):\n                # Remove ThinkingPart, keep only text\n                text_content = \"\".join(p[\"text\"] for p in content if p[\"type\"] == \"text\")\n            else:\n                # Strip <think>...</think> blocks from string content\n                text_content = re.sub(r\"<think>.*?</think>\", \"\", content, flags=re.DOTALL)\n\n            message = message.copy()\n            message[\"content\"] = text_content\n\n        # Call parent to get base rendering\n        rendered = super().render_message(message, ctx)\n\n        # Add </think> to header for assistant messages (not following tool)\n        # This goes in header (weight=0) so observation matches generation prompt.\n        if message[\"role\"] == \"assistant\" and not follows_tool:\n            think_close_tokens = self.tokenizer.encode(\"</think>\", add_special_tokens=False)\n            old_header_tokens = list(rendered.header.tokens) if rendered.header else []\n            new_header = tinker.EncodedTextChunk(tokens=old_header_tokens + think_close_tokens)\n            rendered = RenderedMessage(header=new_header, output=rendered.output)\n\n        return rendered\n"
  },
  {
    "path": "tinker_cookbook/renderers/deepseek_v3_test.py",
    "content": "\"\"\"Tests specific to DeepSeek V3 renderers (parse_response, tool call behavior, streaming).\"\"\"\n\nimport pytest\nimport tinker\n\nfrom tinker_cookbook.renderers import (\n    Message,\n    RenderContext,\n    StreamingMessageHeader,\n    TextPart,\n    ThinkingPart,\n    ToolCall,\n)\nfrom tinker_cookbook.renderers.base import ensure_list\nfrom tinker_cookbook.renderers.deepseek_v3 import (\n    DeepSeekV3DisableThinkingRenderer,\n    DeepSeekV3ThinkingRenderer,\n)\nfrom tinker_cookbook.renderers.testing_utils import skip_deepseek_tokenizer_bug\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\npytestmark = skip_deepseek_tokenizer_bug\n\n# =============================================================================\n# DeepSeek parse_response Tests\n# =============================================================================\n\n\ndef test_deepseek_parse_response_extracts_thinking():\n    \"\"\"Test DeepSeekV3ThinkingRenderer.parse_response extracts thinking.\"\"\"\n    tokenizer = get_tokenizer(\"deepseek-ai/DeepSeek-V3.1\")\n    renderer = DeepSeekV3ThinkingRenderer(tokenizer)\n\n    # Note: DeepSeek uses full-width pipes in special tokens\n    response_str = \"Let me think about this.</think>The answer is 42.<｜end▁of▁sentence｜>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n\n    thinking_parts = [p for p in content if p[\"type\"] == \"thinking\"]\n    text_parts = [p for p in content if p[\"type\"] == \"text\"]\n\n    assert len(thinking_parts) == 1\n    assert thinking_parts[0][\"thinking\"] == \"Let me think about this.\"\n    assert len(text_parts) == 1\n    assert text_parts[0][\"text\"] == \"The answer is 42.\"\n\n\ndef test_deepseek_parse_response_no_thinking_returns_string():\n    \"\"\"Test DeepSeekV3ThinkingRenderer.parse_response returns string when no thinking.\"\"\"\n    tokenizer = get_tokenizer(\"deepseek-ai/DeepSeek-V3.1\")\n    renderer = DeepSeekV3ThinkingRenderer(tokenizer)\n\n    response_str = \"Just a plain response.<｜end▁of▁sentence｜>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    assert isinstance(message[\"content\"], str)\n    assert message[\"content\"] == \"Just a plain response.\"\n\n\ndef test_deepseek_parse_response_multiple_think_blocks():\n    \"\"\"Test DeepSeekV3ThinkingRenderer.parse_response handles multiple think blocks.\"\"\"\n    tokenizer = get_tokenizer(\"deepseek-ai/DeepSeek-V3.1\")\n    renderer = DeepSeekV3ThinkingRenderer(tokenizer)\n\n    response_str = \"step 1</think>partial<think>step 2</think>final<｜end▁of▁sentence｜>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n    assert len(content) == 4\n\n    assert content[0] == ThinkingPart(type=\"thinking\", thinking=\"step 1\")\n    assert content[1] == TextPart(type=\"text\", text=\"partial\")\n    assert content[2] == ThinkingPart(type=\"thinking\", thinking=\"step 2\")\n    assert content[3] == TextPart(type=\"text\", text=\"final\")\n\n\n# =============================================================================\n# DeepSeek Tool Call 
/ Formatting Tests\n# =============================================================================\n\n\ndef test_deepseek_thinking_preserved_with_tool_calls():\n    \"\"\"\n    Test that thinking is preserved in messages that have tool_calls.\n    The thinking represents the model's reasoning about WHY it's making the tool call.\n    \"\"\"\n    model_name = \"deepseek-ai/DeepSeek-V3.1\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = DeepSeekV3ThinkingRenderer(tokenizer)  # Default strip_thinking_from_history=True\n\n    messages: list[Message] = [\n        {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": \"<think>I need to check the weather.</think>Let me look that up.\",\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"content\": '{\"temperature\": 72}',\n            \"tool_call_id\": \"call_1\",\n        },\n        {\"role\": \"assistant\", \"content\": \"The temperature in NYC is 72°F.\"},\n    ]\n\n    model_input, _ = renderer.build_supervised_example(messages)\n    decoded = tokenizer.decode(model_input.to_ints())\n\n    # Thinking in message with tool_calls should be preserved\n    assert \"I need to check the weather\" in decoded, (\n        f\"Thinking in tool_call message should be preserved: {decoded}\"\n    )\n\n\ndef test_deepseek_post_tool_formatting():\n    \"\"\"\n    Test that assistant messages following tool responses have correct formatting.\n    Post-tool assistant messages should not have the role token or </think> prefix.\n    \"\"\"\n    model_name = \"deepseek-ai/DeepSeek-V3.1\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = DeepSeekV3ThinkingRenderer(tokenizer)\n\n    messages: list[Message] = [\n        {\"role\": \"user\", \"content\": \"What's the weather?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": \"Let me check.\",\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"content\": '{\"temperature\": 72}',\n            \"tool_call_id\": \"call_1\",\n        },\n        {\"role\": \"assistant\", \"content\": \"The temperature is 72°F.\"},\n    ]\n\n    for idx, message in enumerate(messages):\n        ctx = RenderContext(\n            idx=idx,\n            is_last=idx == len(messages) - 1,\n            prev_message=messages[idx - 1] if idx > 0 else None,\n        )\n        follows_tool = ctx.prev_message is not None and ctx.prev_message[\"role\"] == \"tool\"\n        rendered = renderer.render_message(message, ctx)\n\n        if message[\"role\"] == \"assistant\" and follows_tool:\n            # Post-tool assistant should have no header (no role token)\n            header = rendered.header\n            assert header is None or len(header.tokens) == 0, (\n                f\"Post-tool assistant should have no header, got: {header}\"\n            )\n\n     
       # Output should not start with </think>\n            output_chunk = rendered.output[0]\n            assert isinstance(output_chunk, tinker.EncodedTextChunk), \"Expected EncodedTextChunk\"\n            output_str = str(tokenizer.decode(list(output_chunk.tokens)))\n            assert not output_str.startswith(\"</think>\"), (\n                f\"Post-tool assistant should not have </think> prefix: {output_str}\"\n            )\n\n\n# =============================================================================\n# DeepSeek Streaming Tests\n# =============================================================================\n\n\ndef _is_message(obj) -> bool:\n    return isinstance(obj, dict) and \"role\" in obj and \"content\" in obj\n\n\ndef _assert_deepseek_streaming_matches_batch(renderer, response_str: str):\n    \"\"\"Helper: verify streaming and batch parsing produce identical results.\"\"\"\n    tokenizer = renderer.tokenizer\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    batch_message, batch_success = renderer.parse_response(response_tokens)\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert len(deltas) >= 2, \"Should have at least header + final message\"\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert _is_message(deltas[-1])\n\n    streaming_message = deltas[-1]\n    assert streaming_message[\"role\"] == batch_message[\"role\"]\n    assert ensure_list(streaming_message[\"content\"]) == ensure_list(batch_message[\"content\"])\n\n    return deltas, batch_message\n\n\nclass TestDeepSeekStreamingBatchEquivalence:\n    \"\"\"Verify parse_response_streaming matches parse_response for DeepSeek patterns.\"\"\"\n\n    @pytest.fixture\n    def thinking_renderer(self):\n        tokenizer = get_tokenizer(\"deepseek-ai/DeepSeek-V3.1\")\n        return DeepSeekV3ThinkingRenderer(tokenizer)\n\n    @pytest.fixture\n    def non_thinking_renderer(self):\n        tokenizer = get_tokenizer(\"deepseek-ai/DeepSeek-V3.1\")\n        return DeepSeekV3DisableThinkingRenderer(tokenizer)\n\n    def test_simple_text(self, thinking_renderer):\n        _assert_deepseek_streaming_matches_batch(\n            thinking_renderer, \"Hello, world!<｜end▁of▁sentence｜>\"\n        )\n\n    def test_thinking_then_text(self, thinking_renderer):\n        _assert_deepseek_streaming_matches_batch(\n            thinking_renderer,\n            \"Let me think about this.</think>The answer is 42.<｜end▁of▁sentence｜>\",\n        )\n\n    def test_multiple_think_blocks(self, thinking_renderer):\n        _assert_deepseek_streaming_matches_batch(\n            thinking_renderer,\n            \"step 1</think>partial<think>step 2</think>final<｜end▁of▁sentence｜>\",\n        )\n\n    def test_empty_response(self, thinking_renderer):\n        _assert_deepseek_streaming_matches_batch(thinking_renderer, \"<｜end▁of▁sentence｜>\")\n\n    def test_non_thinking_renderer(self, non_thinking_renderer):\n        _assert_deepseek_streaming_matches_batch(\n            non_thinking_renderer, \"Direct answer.<｜end▁of▁sentence｜>\"\n        )\n\n    def test_no_end_token(self, thinking_renderer):\n        \"\"\"Truncated response — streaming should still parse think blocks.\"\"\"\n        tokenizer = thinking_renderer.tokenizer\n        response_tokens = tokenizer.encode(\"reasoning</think>partial\", add_special_tokens=False)\n\n        deltas = list(thinking_renderer.parse_response_streaming(response_tokens))\n        final = deltas[-1]\n        assert 
_is_message(final)\n        content = final[\"content\"]\n        assert isinstance(content, list), \"Truncated response should still parse think blocks\"\n        thinking = [p for p in content if p[\"type\"] == \"thinking\"]\n        text = [p for p in content if p[\"type\"] == \"text\"]\n        assert len(thinking) == 1 and thinking[0][\"thinking\"] == \"reasoning\"\n        assert len(text) == 1 and text[0][\"text\"] == \"partial\"\n"
  },
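  {
    "path": "docs/sketches/deepseek_renderer_usage_sketch.py",
    "content": "\"\"\"Illustrative usage sketch (hypothetical, not part of the library): batch vs. streaming parsing with DeepSeekV3ThinkingRenderer.\n\nThe import locations below are assumptions (these names are expected to be importable from the renderers package, like other names used in the tests); adjust them to the actual module layout.\n\"\"\"\n\n# Assumed import paths -- not confirmed by the surrounding files.\nfrom tinker_cookbook.renderers import DeepSeekV3ThinkingRenderer, StreamingMessageHeader\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\ndef main() -> None:\n    tokenizer = get_tokenizer(\"deepseek-ai/DeepSeek-V3.1\")\n    renderer = DeepSeekV3ThinkingRenderer(tokenizer)\n\n    # A raw sampled response: thinking, then visible text, then the end token.\n    response_str = \"Let me think.</think>The answer is 42.<｜end▁of▁sentence｜>\"\n    tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    # Batch parsing returns the final Message plus a success flag.\n    message, success = renderer.parse_response(tokens)\n    assert success and message[\"role\"] == \"assistant\"\n\n    # Streaming parsing yields a StreamingMessageHeader first and the final\n    # Message dict last; its content matches the batch result modulo ensure_list\n    # normalization (see the equivalence tests above).\n    deltas = list(renderer.parse_response_streaming(tokens))\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert deltas[-1][\"role\"] == message[\"role\"]\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },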
  {
    "path": "tinker_cookbook/renderers/gpt_oss.py",
    "content": "\"\"\"GptOssRenderer - OpenAI's open source model format (Harmony).\"\"\"\n\nimport json\nimport re\nimport warnings\nfrom datetime import datetime\n\nimport tinker\nimport torch\n\nfrom tinker_cookbook.exceptions import RendererError\nfrom tinker_cookbook.renderers.base import (\n    ContentPart,\n    Message,\n    RenderContext,\n    RenderedMessage,\n    Renderer,\n    Role,\n    TextPart,\n    ThinkingPart,\n    ToolCall,\n    ToolSpec,\n    TrainOnWhat,\n    UnparsedToolCall,\n    ensure_list,\n    ensure_text,\n)\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n# =============================================================================\n# TypeScript formatting utilities (stateless, used for Harmony tool definitions)\n# =============================================================================\n\n\ndef _json_type_to_typescript(schema: dict) -> str:\n    \"\"\"Convert a single JSON schema type to TypeScript.\"\"\"\n    if \"oneOf\" in schema:\n        return \" | \".join(_json_type_to_typescript(s) for s in schema[\"oneOf\"])\n    if \"anyOf\" in schema:\n        return \" | \".join(_json_type_to_typescript(s) for s in schema[\"anyOf\"])\n\n    json_type = schema.get(\"type\", \"any\")\n\n    if isinstance(json_type, list):\n        return \" | \".join(_json_type_to_typescript({\"type\": t}) for t in json_type)\n\n    if json_type == \"string\":\n        if \"enum\" in schema:\n            return \" | \".join(json.dumps(v) for v in schema[\"enum\"])\n        base_type = \"string\"\n    elif json_type == \"number\" or json_type == \"integer\":\n        base_type = \"number\"\n    elif json_type == \"boolean\":\n        base_type = \"boolean\"\n    elif json_type == \"array\":\n        items_type = _json_type_to_typescript(schema.get(\"items\", {}))\n        base_type = f\"{items_type}[]\"\n    elif json_type == \"object\":\n        base_type = _json_schema_to_typescript(schema)\n    else:\n        base_type = \"any\"\n\n    if schema.get(\"nullable\"):\n        return f\"{base_type} | null\"\n    return base_type\n\n\ndef _json_schema_to_typescript(schema: dict) -> str:\n    \"\"\"Convert JSON schema to an inline TypeScript-ish type string.\"\"\"\n    if schema.get(\"type\") != \"object\":\n        return \"any\"\n\n    properties = schema.get(\"properties\", {})\n    required = set(schema.get(\"required\", []))\n\n    type_parts = []\n    for prop_name, prop_schema in properties.items():\n        prop_type = _json_type_to_typescript(prop_schema)\n        optional = \"\" if prop_name in required else \"?\"\n        type_parts.append(f\"{prop_name}{optional}: {prop_type}\")\n\n    return \"{ \" + \", \".join(type_parts) + \" }\"\n\n\ndef _schema_comments(schema: dict) -> list[str]:\n    \"\"\"Extract comments from schema (title, description, examples).\"\"\"\n    comments: list[str] = []\n    title = schema.get(\"title\")\n    if title:\n        comments.append(str(title))\n        comments.append(\"\")\n    description = schema.get(\"description\")\n    if description:\n        comments.append(str(description))\n    examples = schema.get(\"examples\")\n    if examples:\n        comments.append(\"Examples:\")\n        for example in examples:\n            comments.append(f\"- {json.dumps(example)}\")\n    return comments\n\n\ndef _format_parameters_block(schema: dict) -> str:\n    \"\"\"Format function parameters as a TypeScript-style block.\"\"\"\n    if schema.get(\"type\") != \"object\" or not schema.get(\"properties\"):\n        return \"()\"\n\n    
lines = []\n    header = \"(_:\"\n    schema_description = schema.get(\"description\")\n    if schema_description:\n        header += f\" // {schema_description}\"\n    lines.append(header)\n    lines.append(\"{\")\n\n    properties = schema.get(\"properties\", {})\n    required = set(schema.get(\"required\", []))\n    for prop_name, prop_schema in properties.items():\n        for comment in _schema_comments(prop_schema):\n            lines.append(f\"// {comment}\")\n        prop_type = _json_type_to_typescript(prop_schema)\n        optional = \"\" if prop_name in required else \"?\"\n        default_comment = \"\"\n        if \"default\" in prop_schema:\n            default_comment = f\" // default: {json.dumps(prop_schema['default'])}\"\n        lines.append(f\"{prop_name}{optional}: {prop_type},{default_comment}\")\n\n    lines.append(\"})\")\n    return \"\\n\".join(lines)\n\n\ndef _format_tool_definition(tool: ToolSpec) -> str:\n    \"\"\"Format a single tool as a Harmony TypeScript-style definition.\"\"\"\n    lines = []\n    if tool.get(\"description\"):\n        lines.append(f\"// {tool['description']}\")\n\n    params = tool.get(\"parameters\") or {}\n    params_block = _format_parameters_block(params)\n    lines.append(f\"type {tool['name']} = {params_block} => any;\")\n    return \"\\n\".join(lines)\n\n\nclass GptOssRenderer(Renderer):\n    \"\"\"\n    Renderer for OpenAI's open source models using the Harmony format.\n\n    Wire format: <|start|>role<|channel|>channel<|message|>content<|end|>\n    No newlines between messages. Last assistant message ends with <|return|>;\n    historical assistant messages end with <|end|>.\n\n    Harmony Channels\n    ----------------\n    Each assistant message specifies a \"channel\" that controls how the content is\n    interpreted and displayed. An assistant turn can have multiple channel segments\n    (rendered as separate <|start|>assistant... blocks):\n\n    - analysis: Chain-of-thought reasoning (hidden from end users, like <think> blocks)\n    - commentary: Tool calls to developer-defined functions, or user-visible \"preambles\"\n      before tool calls. Uses `to=functions.name` to route to specific tools.\n    - final: The user-facing response text\n\n    A typical assistant turn with thinking + tool call + final answer would render as:\n        <|start|>assistant<|channel|>analysis<|message|>{thinking}<|end|>\n        <|start|>assistant to=functions.get_weather<|channel|>commentary <|constrain|>json<|message|>{args}<|call|>\n        ... (tool result) ...\n        <|start|>assistant<|channel|>final<|message|>{answer}<|return|>\n\n    Tool Calling\n    ------------\n    - Tool definitions: Go in developer message with TypeScript-style syntax\n    - Tool calls: <|start|>assistant to=functions.name<|channel|>commentary <|constrain|>json<|message|>{args}<|call|>\n    - Tool results: <|start|>functions.name to=assistant<|channel|>commentary<|message|>{result}<|end|>\n\n    Reference: https://raw.githubusercontent.com/openai/openai-cookbook/main/articles/openai-harmony.md\n    \"\"\"\n\n    # System prompt content (without rendering tokens). 
Tool channel instructions are NOT\n    # included here; they are only added when tools are defined in the developer message.\n    system_prompt_content = (\n        \"You are ChatGPT, a large language model trained by OpenAI.\\n\"\n        \"Knowledge cutoff: 2024-06\\n\"\n        \"Current date: {current_date}\\n\\n\"\n        \"Reasoning: {reasoning_effort}\\n\\n\"\n        \"# Valid channels: analysis, commentary, final. Channel must be included for every message.\"\n    )\n    use_system_prompt: bool = False\n    reasoning_effort: str | None = None\n    current_date: str | None = (\n        None  # If use_system_prompt=True, will use the current date if this is None. Set this to a fixed date for deterministic system prompt.\n    )\n\n    def __init__(\n        self,\n        tokenizer: Tokenizer,\n        use_system_prompt: bool = False,\n        reasoning_effort: str | None = None,\n        current_date: str | None = None,\n    ):\n        super().__init__(tokenizer)\n        self.use_system_prompt = use_system_prompt\n        self.reasoning_effort = reasoning_effort\n        self.current_date = current_date\n        assert use_system_prompt == (reasoning_effort is not None), (\n            \"Reasoning effort must be set iff using system prompt\"\n        )\n\n    # Internal role for OpenAI's system prompt (bypasses system->developer mapping)\n    _INTERNAL_SYSTEM_ROLE = \"_gptoss_internal_system\"\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        role = message[\"role\"]\n\n        # Handle tool result messages (role=\"tool\")\n        if role == \"tool\":\n            return self._render_tool_result_message(message, ctx)\n\n        # Internal system role renders as actual \"system\" without transformation\n        if role == self._INTERNAL_SYSTEM_ROLE:\n            role = \"system\"\n        # User-provided \"system\" messages map to \"developer\" (per HF template)\n        elif role == \"system\":\n            role = \"developer\"\n\n        header_str = f\"<|start|>{role}\"\n        output_str = \"\"\n        tool_calls: list[ToolCall] = []\n\n        if message[\"role\"] == \"assistant\":\n            # Assistant channels. 
See https://cookbook.openai.com/articles/openai-harmony\n            # Extract text and thinking from content list\n            parts = ensure_list(message[\"content\"])\n            text_content = \"\".join(p[\"text\"] for p in parts if p[\"type\"] == \"text\")\n            thinking_content = \"\".join(p[\"thinking\"] for p in parts if p[\"type\"] == \"thinking\")\n            tool_calls = message.get(\"tool_calls\") or []\n\n            # Analysis channel (CoT) - only if there's thinking content\n            if thinking_content:\n                output_str += (\n                    f\"<|channel|>analysis<|message|>{thinking_content}<|end|><|start|>assistant\"\n                )\n\n            # Handle tool calls (goes in commentary channel)\n            if tool_calls:\n                # If there's text content with tool calls, render as commentary preamble first\n                if text_content:\n                    output_str += (\n                        f\"<|channel|>commentary<|message|>{text_content}<|end|><|start|>assistant\"\n                    )\n                output_str += self._render_tool_calls(tool_calls)\n            else:\n                # Final channel (Response Content)\n                output_str += f\"<|channel|>final<|message|>{text_content}\"\n        elif message[\"role\"] == \"system\":\n            # User-provided system messages get \"# Instructions\" wrapper (rendered as developer)\n            output_str += f\"<|message|># Instructions\\n\\n{ensure_text(message['content'])}\\n\\n\"\n        else:\n            # user, developer, internal system, and other roles: plain content\n            output_str += f\"<|message|>{ensure_text(message['content'])}\"\n\n        # End token logic:\n        # - Tool calls: each tool call already includes <|call|> via _render_tool_calls, no end token needed\n        # - Assistant (no tool calls): <|return|> if last message, <|end|> otherwise\n        # - All other roles: <|end|>\n        if message[\"role\"] == \"assistant\":\n            if not tool_calls:\n                if ctx.is_last:\n                    output_str += \"<|return|>\"\n                else:\n                    output_str += \"<|end|>\"\n            # Note: tool_calls case needs no end token here - _render_tool_calls adds <|call|>\n        else:\n            output_str += \"<|end|>\"\n\n        header = tinker.types.EncodedTextChunk(\n            tokens=self.tokenizer.encode(header_str, add_special_tokens=False)\n        )\n        output: list[tinker.ModelInputChunk] = [\n            tinker.types.EncodedTextChunk(\n                tokens=self.tokenizer.encode(output_str, add_special_tokens=False)\n            )\n        ]\n        return RenderedMessage(header=header, output=output)\n\n    def _render_tool_calls(self, tool_calls: list[ToolCall]) -> str:\n        \"\"\"Render tool calls in Harmony commentary channel format.\n\n        Each tool call becomes a separate commentary message:\n        to=functions.name<|channel|>commentary <|constrain|>json<|message|>{args}\n\n        Multiple tool calls are separated by <|call|><|start|>assistant.\n        \"\"\"\n        result_parts = []\n        for i, tc in enumerate(tool_calls):\n            # Format: to=functions.name<|channel|>commentary <|constrain|>json<|message|>{args}\n            result_parts.append(\n                f\" to=functions.{tc.function.name}<|channel|>commentary <|constrain|>json<|message|>\"\n                f\"{tc.function.arguments}<|call|>\"\n            )\n            # If not the 
last tool call, close message and start new assistant message\n            if i < len(tool_calls) - 1:\n                result_parts.append(\"<|start|>assistant\")\n        return \"\".join(result_parts)\n\n    def _render_tool_result_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        \"\"\"Render a tool result message.\n\n        Format: <|start|>functions.name to=assistant<|channel|>commentary<|message|>{result}<|end|>\n\n        IMPORTANT: The tool name MUST be provided in the message's \"name\" field.\n        The renderer is stateless and cannot track tool_call_id -> name mappings.\n        When constructing tool result messages, always include the \"name\" field:\n\n            {\"role\": \"tool\", \"name\": \"get_weather\", \"content\": \"72 degrees\", \"tool_call_id\": \"...\"}\n\n        If \"name\" is missing, this will produce \"functions.unknown\" which is incorrect.\n        \"\"\"\n        # Get the tool name from the \"name\" field\n        tool_name = message.get(\"name\", \"\")\n        if not tool_name:\n            warnings.warn(\n                \"Tool message missing 'name' field. GptOssRenderer requires the 'name' field \"\n                \"to render tool results correctly. Add 'name' to your tool messages: \"\n                \"{'role': 'tool', 'name': 'function_name', 'content': '...', 'tool_call_id': '...'}\",\n                UserWarning,\n                stacklevel=3,\n            )\n            tool_name = \"unknown\"\n\n        # Ensure qualified with \"functions.\" prefix\n        if not tool_name.startswith(\"functions.\"):\n            tool_name = f\"functions.{tool_name}\"\n\n        # Build the header with tool name as role and to=assistant\n        header_str = f\"<|start|>{tool_name} to=assistant\"\n\n        # Tool results go in commentary channel\n        content = ensure_text(message[\"content\"])\n        output_str = f\"<|channel|>commentary<|message|>{content}<|end|>\"\n\n        header = tinker.types.EncodedTextChunk(\n            tokens=self.tokenizer.encode(header_str, add_special_tokens=False)\n        )\n        output: list[tinker.ModelInputChunk] = [\n            tinker.types.EncodedTextChunk(\n                tokens=self.tokenizer.encode(output_str, add_special_tokens=False)\n            )\n        ]\n        return RenderedMessage(header=header, output=output)\n\n    def _get_system_message(self) -> Message | None:\n        \"\"\"Return system message if configured, else None.\n\n        Uses internal role to render as actual 'system' (not mapped to 'developer').\n        \"\"\"\n        if not self.use_system_prompt:\n            return None\n        current_date = (\n            self.current_date\n            if self.current_date is not None\n            else datetime.now().strftime(\"%Y-%m-%d\")\n        )\n        content = self.system_prompt_content.format(\n            current_date=current_date,\n            reasoning_effort=self.reasoning_effort,\n        )\n        return Message(role=self._INTERNAL_SYSTEM_ROLE, content=content)\n\n    @property\n    def _bos_tokens(self) -> list[int]:\n        # GptOss has no BOS token. 
System prompt is prepended as a message.\n        return []\n\n    def _warn_if_user_system_message(self, messages: list[Message]) -> None:\n        \"\"\"Warn if user provides system message when use_system_prompt=True.\"\"\"\n        if self.use_system_prompt and messages and messages[0][\"role\"] == \"system\":\n            warnings.warn(\n                \"use_system_prompt=True but messages already start with a system message. \"\n                \"The built-in system prompt will be prepended, resulting in two system messages. \"\n                \"Either set use_system_prompt=False or remove the system message from your messages.\",\n                UserWarning,\n                stacklevel=3,\n            )\n\n    def build_generation_prompt(\n        self, messages: list[Message], role: Role = \"assistant\", prefill: str | None = None\n    ) -> tinker.ModelInput:\n        \"\"\"Build generation prompt, prepending system message if configured.\"\"\"\n        self._warn_if_user_system_message(messages)\n        system_msg = self._get_system_message()\n        if system_msg:\n            messages = [system_msg] + list(messages)\n        return super().build_generation_prompt(messages, role, prefill)\n\n    def build_supervised_example(\n        self,\n        messages: list[Message],\n        train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE,\n    ) -> tuple[tinker.ModelInput, torch.Tensor]:\n        \"\"\"Build supervised example, prepending system message if configured.\"\"\"\n        self._warn_if_user_system_message(messages)\n        system_msg = self._get_system_message()\n        if system_msg:\n            messages = [system_msg] + list(messages)\n        return super().build_supervised_example(messages, train_on_what)\n\n    @property\n    def _return_token(self) -> int:\n        res = self.tokenizer.encode(\"<|return|>\", add_special_tokens=False)\n        assert len(res) == 1, f\"Expected single token for <|return|>, got {len(res)}\"\n        return res[0]\n\n    @property\n    def _call_token(self) -> int:\n        res = self.tokenizer.encode(\"<|call|>\", add_special_tokens=False)\n        assert len(res) == 1, f\"Expected single token for <|call|>, got {len(res)}\"\n        return res[0]\n\n    def get_stop_sequences(self) -> list[int]:\n        # Both <|return|> and <|call|> are stop tokens\n        # <|return|> for normal completion, <|call|> for tool calls\n        return [self._return_token, self._call_token]\n\n    def parse_response(self, response: list[int]) -> tuple[Message, bool]:\n        call_count = response.count(self._call_token)\n        return_count = response.count(self._return_token)\n        if call_count == 0 and return_count == 0:\n            str_response = str(self.tokenizer.decode(response))\n            return Message(role=\"assistant\", content=str_response), False\n        if call_count > 1:\n            raise RendererError(\n                f\"When parsing response, expected at most 1 <|call|> token, but got {call_count}. \"\n                \"You probably are using the wrong stop tokens when sampling\"\n            )\n        if return_count > 1:\n            raise RendererError(\n                f\"When parsing response, expected at most 1 <|return|> token, but got {return_count}. 
\"\n                \"You probably are using the wrong stop tokens when sampling\"\n            )\n\n        stop_idx = response.index(self._return_token) if return_count else None\n        if call_count:\n            call_idx = response.index(self._call_token)\n            if stop_idx is None or call_idx < stop_idx:\n                stop_idx = call_idx\n\n        assert stop_idx is not None\n        str_response = str(self.tokenizer.decode(response[:stop_idx]))\n        parts, tool_calls, unparsed = self._parse_harmony_output(str_response)\n        content: list[ContentPart] | str = parts if parts else str_response\n\n        message: Message = {\"role\": \"assistant\", \"content\": content}\n        if tool_calls:\n            message[\"tool_calls\"] = tool_calls\n        if unparsed:\n            message[\"unparsed_tool_calls\"] = unparsed\n\n        return message, True\n\n    def to_openai_message(self, message: Message) -> dict:\n        \"\"\"Convert a Message to OpenAI API format with reasoning_content for thinking.\n\n        GptOss uses the analysis channel for thinking, which maps to reasoning_content\n        in OpenAI's API format.\n        \"\"\"\n        result: dict = {\"role\": message[\"role\"]}\n\n        content = message[\"content\"]\n        if isinstance(content, str):\n            result[\"content\"] = content\n        else:\n            # Extract thinking into reasoning_content, keep text in content\n            thinking_parts = []\n            text_parts = []\n            for p in content:\n                if p[\"type\"] == \"thinking\":\n                    thinking_parts.append(p[\"thinking\"])\n                elif p[\"type\"] == \"text\":\n                    text_parts.append(p[\"text\"])\n\n            result[\"content\"] = \"\".join(text_parts)\n            if thinking_parts:\n                result[\"reasoning_content\"] = \"\".join(thinking_parts)\n\n        # Handle tool_calls\n        if \"tool_calls\" in message and message[\"tool_calls\"]:  # noqa: RUF019\n            result[\"tool_calls\"] = [\n                {\n                    \"type\": \"function\",\n                    \"id\": tc.id,\n                    \"function\": {\n                        \"name\": tc.function.name,\n                        \"arguments\": tc.function.arguments,\n                    },\n                }\n                for tc in message[\"tool_calls\"]\n            ]\n\n        # Handle tool response fields\n        if message[\"role\"] == \"tool\":\n            if \"tool_call_id\" in message:\n                result[\"tool_call_id\"] = message[\"tool_call_id\"]\n            if \"name\" in message:\n                result[\"name\"] = message[\"name\"]\n\n        return result\n\n    def _parse_harmony_output(\n        self, content: str\n    ) -> tuple[list[ContentPart], list[ToolCall], list[UnparsedToolCall]]:\n        messages = self._parse_harmony_messages(content)\n        parts: list[ContentPart] = []\n        tool_calls: list[ToolCall] = []\n        unparsed: list[UnparsedToolCall] = []\n\n        for msg in messages:\n            msg_content = msg[\"content\"] or \"\"\n            msg_raw_text = msg[\"raw_text\"] or \"\"\n            if not msg_content.strip():\n                continue\n\n            recipient = msg[\"recipient\"]\n            if recipient and recipient.startswith(\"functions.\"):\n                tool_name = recipient.split(\"functions.\", 1)[1]\n                try:\n                    json.loads(msg_content)\n                    
tool_calls.append(\n                        ToolCall(\n                            function=ToolCall.FunctionBody(\n                                name=tool_name, arguments=msg_content.strip()\n                            ),\n                            id=None,  # Harmony format doesn't include tool call IDs\n                        )\n                    )\n                except json.JSONDecodeError as e:\n                    unparsed.append(\n                        UnparsedToolCall(raw_text=msg_raw_text, error=f\"Invalid JSON: {e}\")\n                    )\n                continue\n\n            channel = msg[\"channel\"]\n            if channel == \"analysis\":\n                parts.append(ThinkingPart(type=\"thinking\", thinking=msg_content))\n            elif channel == \"final\" or channel == \"commentary\":\n                parts.append(TextPart(type=\"text\", text=msg_content))\n\n        return parts, tool_calls, unparsed\n\n    def _parse_harmony_messages(self, content: str) -> list[dict[str, str | None]]:\n        \"\"\"Parse Harmony format content into a list of message dicts.\n\n        Uses manual string parsing (find/rfind) rather than regex. This approach\n        is intentional: it will continue to work if we move away from using\n        stringified tokens, which would be preferable for robustness.\n        \"\"\"\n        messages: list[dict[str, str | None]] = []\n        idx = 0\n        message_token = \"<|message|>\"\n        end_tokens = (\"<|end|>\", \"<|call|>\", \"<|return|>\")\n\n        while True:\n            message_idx = content.find(message_token, idx)\n            if message_idx == -1:\n                break\n\n            header_start = content.rfind(\"<|start|>\", idx, message_idx)\n            if header_start == -1:\n                header_start = idx\n            header = content[header_start:message_idx]\n\n            content_start = message_idx + len(message_token)\n            end_idx = len(content)\n            end_token = \"\"\n            for token in end_tokens:\n                token_idx = content.find(token, content_start)\n                if token_idx != -1 and token_idx < end_idx:\n                    end_idx = token_idx\n                    end_token = token\n\n            body = content[content_start:end_idx]\n\n            channel = None\n            channel_match = re.search(r\"<\\|channel\\|>([^<\\s]+)\", header)\n            if channel_match:\n                channel = channel_match.group(1)\n\n            recipient = None\n            recipient_match = re.search(r\"to=([^\\s<]+)\", header)\n            if recipient_match:\n                recipient = recipient_match.group(1)\n\n            content_type = None\n            content_type_match = re.search(r\"<\\|constrain\\|>\\s*([^\\s<]+)\", header)\n            if content_type_match:\n                content_type = content_type_match.group(1)\n\n            messages.append(\n                {\n                    \"channel\": channel,\n                    \"recipient\": recipient,\n                    \"content_type\": content_type,\n                    \"content\": body,\n                    \"raw_text\": content[header_start : end_idx + len(end_token)]\n                    if end_token\n                    else content[header_start:],\n                }\n            )\n\n            idx = end_idx + len(end_token)\n\n        return messages\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> 
list[Message]:\n        \"\"\"Create conversation prefix with tools in Harmony format.\n\n        Returns a list of messages to prepend to conversations:\n        1. If tools present: A system message with tool routing instruction\n        2. A developer message with user instructions and tool definitions\n\n        Tools are defined using TypeScript-ish syntax in a `functions` namespace,\n        following the OpenAI Harmony spec.\n\n        Note: When using this with tools, you typically don't need use_system_prompt=True\n        since this method provides the necessary system setup for tool routing.\n\n        Reference: https://raw.githubusercontent.com/openai/openai-cookbook/main/articles/openai-harmony.md\n        \"\"\"\n        messages: list[Message] = []\n\n        # Tool routing instruction goes in system message (per Harmony spec)\n        if tools:\n            messages.append(\n                Message(\n                    role=self._INTERNAL_SYSTEM_ROLE,\n                    content=\"Calls to these tools must go to the commentary channel: 'functions'.\",\n                )\n            )\n\n        # User instructions and tool definitions go in developer message\n        content_parts: list[str] = []\n        if system_prompt:\n            content_parts.append(f\"# Instructions\\n\\n{system_prompt}\")\n\n        if tools:\n            tool_defs = [_format_tool_definition(tool) for tool in tools]\n            tools_text = \"\\n\\n\".join(tool_defs)\n            content_parts.append(\n                \"# Tools\\n\\n## functions\\n\\nnamespace functions {\\n\\n\"\n                f\"{tools_text}\\n\\n\"\n                \"} // namespace functions\"\n            )\n\n        if content_parts:\n            content = \"\\n\\n\".join(content_parts)\n            messages.append(Message(role=\"developer\", content=content))\n\n        return messages\n"
  },
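  {
    "path": "docs/sketches/gpt_oss_renderer_usage_sketch.py",
    "content": "\"\"\"Illustrative usage sketch (hypothetical, not part of the library): drive GptOssRenderer end to end -- build a generation prompt, then parse a Harmony-format response.\n\nThe response string is hand-written for illustration; in practice it would come from sampling with renderer.get_stop_sequences() as the stop tokens.\n\"\"\"\n\nfrom tinker_cookbook.renderers.gpt_oss import GptOssRenderer\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\ndef main() -> None:\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    # Prompt for a new assistant turn; the built-in system prompt is prepended\n    # because use_system_prompt=True.\n    messages = [{\"role\": \"user\", \"content\": \"What is 6 * 7?\"}]\n    prompt = renderer.build_generation_prompt(messages)\n    stop_tokens = renderer.get_stop_sequences()  # <|return|> and <|call|>\n    print(len(prompt.to_ints()), stop_tokens)\n\n    # Parse a Harmony response: the analysis channel becomes a thinking part,\n    # the final channel becomes a text part.\n    response_str = (\n        \"<|channel|>analysis<|message|>Multiply 6 by 7.<|end|>\"\n        \"<|start|>assistant<|channel|>final<|message|>42<|return|>\"\n    )\n    tokens = tokenizer.encode(response_str, add_special_tokens=False)\n    message, success = renderer.parse_response(tokens)\n    assert success\n    assert any(p[\"type\"] == \"thinking\" for p in message[\"content\"])\n    assert any(p[\"type\"] == \"text\" and p[\"text\"] == \"42\" for p in message[\"content\"])\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },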
  {
    "path": "tinker_cookbook/renderers/gpt_oss_test.py",
    "content": "\"\"\"Tests specific to GptOss renderer (parse_response, tool calls, channel parsing).\"\"\"\n\nfrom tinker_cookbook.renderers import (\n    TextPart,\n    ThinkingPart,\n)\nfrom tinker_cookbook.renderers.gpt_oss import GptOssRenderer\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n# =============================================================================\n# GptOss parse_response Tests\n# =============================================================================\n\n\ndef test_gptoss_parse_response_extracts_thinking():\n    \"\"\"Test GptOssRenderer.parse_response extracts analysis channel as thinking.\"\"\"\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    # GptOss format: analysis channel then final channel\n    response_str = \"<|channel|>analysis<|message|>Let me think about this.<|end|><|start|>assistant<|channel|>final<|message|>The answer is 42.<|return|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n\n    thinking_parts = [p for p in content if p[\"type\"] == \"thinking\"]\n    text_parts = [p for p in content if p[\"type\"] == \"text\"]\n\n    assert len(thinking_parts) == 1\n    assert thinking_parts[0][\"thinking\"] == \"Let me think about this.\"\n    assert len(text_parts) == 1\n    assert text_parts[0][\"text\"] == \"The answer is 42.\"\n\n\ndef test_gptoss_parse_response_multiple_analysis():\n    \"\"\"Test GptOssRenderer.parse_response handles multiple analysis messages.\"\"\"\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    response_str = \"<|channel|>analysis<|message|>First thought.<|end|><|start|>assistant<|channel|>analysis<|message|>Second thought.<|end|><|start|>assistant<|channel|>final<|message|>Done.<|return|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n    assert len(content) == 3\n\n    assert content[0] == ThinkingPart(type=\"thinking\", thinking=\"First thought.\")\n    assert content[1] == ThinkingPart(type=\"thinking\", thinking=\"Second thought.\")\n    assert content[2] == TextPart(type=\"text\", text=\"Done.\")\n\n\ndef test_gptoss_parse_response_final_only():\n    \"\"\"Test GptOssRenderer.parse_response with only final channel.\"\"\"\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    response_str = \"<|channel|>final<|message|>Simple answer.<|return|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n    assert len(content) == 1\n    assert content[0] == TextPart(type=\"text\", text=\"Simple answer.\")\n\n\ndef test_gptoss_parse_response_no_channels():\n    \"\"\"Test GptOssRenderer.parse_response returns string when no channel markers.\"\"\"\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = 
GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    response_str = \"Plain response without channels.<|return|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    # No channel markers, so content stays as string\n    assert isinstance(message[\"content\"], str)\n    assert message[\"content\"] == \"Plain response without channels.\"\n\n\ndef test_gptoss_parse_response_tool_call():\n    \"\"\"Test GptOssRenderer.parse_response extracts tool calls from commentary channel.\"\"\"\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    # Tool call format: commentary channel with to=functions.name and <|call|> stop token\n    response_str = '<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\": \"San Francisco\"}<|call|>'\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 1\n    assert message[\"tool_calls\"][0].function.name == \"get_weather\"\n    assert '\"location\"' in message[\"tool_calls\"][0].function.arguments\n\n\ndef test_gptoss_parse_response_tool_call_with_analysis():\n    \"\"\"Test GptOssRenderer.parse_response extracts both thinking and tool calls.\"\"\"\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    response_str = '<|channel|>analysis<|message|>I need to check the weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"city\": \"NYC\"}<|call|>'\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n\n    # Should have thinking from analysis channel\n    thinking_parts = [p for p in content if p[\"type\"] == \"thinking\"]\n    assert len(thinking_parts) >= 1\n    assert \"check the weather\" in thinking_parts[0][\"thinking\"]\n\n    # Should have tool call\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 1\n    assert message[\"tool_calls\"][0].function.name == \"get_weather\"\n\n\ndef test_gptoss_parse_response_invalid_tool_call_json():\n    \"\"\"Test GptOssRenderer.parse_response handles invalid JSON in tool calls.\"\"\"\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    response_str = \"<|channel|>commentary to=functions.broken <|constrain|>json<|message|>not valid json<|call|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    assert \"unparsed_tool_calls\" in message\n    assert len(message[\"unparsed_tool_calls\"]) == 1\n    assert \"Invalid JSON\" in message[\"unparsed_tool_calls\"][0].error\n\n\ndef test_gptoss_parse_response_tool_call_recipient_before_channel():\n    \"\"\"Test GptOssRenderer.parse_response handles recipient before channel.\"\"\"\n    tokenizer = 
get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    response_str = '<|start|>assistant to=functions.get_weather<|channel|>commentary<|constrain|>json<|message|>{\"location\": \"Tokyo\"}<|call|>'\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 1\n    assert message[\"tool_calls\"][0].function.name == \"get_weather\"\n\n\ndef test_gptoss_parse_response_commentary_preamble():\n    \"\"\"Test GptOssRenderer.parse_response keeps commentary preamble text.\"\"\"\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    renderer = GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort=\"medium\")\n\n    response_str = (\n        \"<|channel|>commentary<|message|>Checking now.<|end|>\"\n        '<|start|>assistant to=functions.get_weather<|channel|>commentary <|constrain|>json<|message|>{\"location\": \"SF\"}<|call|>'\n    )\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n    assert len(content) == 1\n    assert content[0] == TextPart(type=\"text\", text=\"Checking now.\")\n    assert \"tool_calls\" in message and len(message[\"tool_calls\"]) == 1\n"
  },
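  {
    "path": "docs/sketches/gpt_oss_tool_calling_sketch.py",
    "content": "\"\"\"Illustrative tool-calling sketch (hypothetical, not part of the library) for GptOssRenderer.\n\nShows the flow described in the renderer docstrings: declare tools in the conversation prefix, parse a commentary-channel tool call, and build the tool-result message with the required \"name\" field. The get_weather tool, its schema, and the tool_call_id are made up; ToolSpec is assumed to be a plain mapping with name/description/parameters keys.\n\"\"\"\n\nfrom tinker_cookbook.renderers.gpt_oss import GptOssRenderer\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\ndef main() -> None:\n    tokenizer = get_tokenizer(\"openai/gpt-oss-20b\")\n    # Per the docstring, the tool prefix supplies the system setup, so\n    # use_system_prompt=True is typically not needed here.\n    renderer = GptOssRenderer(tokenizer)\n\n    # Hypothetical tool definition (JSON-schema parameters).\n    get_weather = {\n        \"name\": \"get_weather\",\n        \"description\": \"Look up the current weather for a city.\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\"location\": {\"type\": \"string\"}},\n            \"required\": [\"location\"],\n        },\n    }\n    prefix = renderer.create_conversation_prefix_with_tools([get_weather], system_prompt=\"Be brief.\")\n    prompt = renderer.build_generation_prompt(prefix + [{\"role\": \"user\", \"content\": \"Weather in SF?\"}])\n    print(len(prompt.to_ints()))\n\n    # Parse a sampled tool call: commentary channel routed to functions.get_weather.\n    response_str = '<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\": \"SF\"}<|call|>'\n    tokens = tokenizer.encode(response_str, add_special_tokens=False)\n    call_msg, success = renderer.parse_response(tokens)\n    assert success and call_msg[\"tool_calls\"][0].function.name == \"get_weather\"\n\n    # Tool results MUST carry the \"name\" field: the renderer is stateless and\n    # cannot recover the function name, so a missing name falls back to\n    # \"functions.unknown\" with a warning.\n    tool_result = {\n        \"role\": \"tool\",\n        \"name\": \"get_weather\",\n        \"content\": '{\"temperature\": 60}',\n        \"tool_call_id\": \"call_1\",  # made-up id; Harmony tool calls carry no ids\n    }\n    print(tool_result[\"name\"], call_msg[\"tool_calls\"][0].function.arguments)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },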
  {
    "path": "tinker_cookbook/renderers/kimi_k2.py",
    "content": "\"\"\"Renderer for Moonshot AI's Kimi K2 models.\"\"\"\n\nimport json\nimport re\nimport warnings\n\nimport tinker\nimport torch\n\nfrom tinker_cookbook.exceptions import RendererError\nfrom tinker_cookbook.renderers.base import (\n    ContentPart,\n    Message,\n    RenderContext,\n    RenderedMessage,\n    Renderer,\n    Role,\n    TextPart,\n    ToolCall,\n    ToolSpec,\n    TrainOnWhat,\n    UnparsedToolCall,\n    ensure_list,\n    ensure_text,\n    parse_response_for_stop_token,\n    parse_think_blocks,\n)\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n_TOOL_CALLS_SECTION_RE = re.compile(\n    r\"<\\|tool_calls_section_begin\\|>(.*?)<\\|tool_calls_section_end\\|>\"\n    r\"|<\\|tool_call_section_begin\\|>(.*?)<\\|tool_call_section_end\\|>\",\n    re.DOTALL,\n)\n_TOOL_CALL_RE = re.compile(\n    r\"<\\|tool_call_begin\\|>\\s*([^<]+:\\d+)\\s*<\\|tool_call_argument_begin\\|>\\s*(.*?)\\s*<\\|tool_call_end\\|>\",\n    re.DOTALL,\n)\n\n\ndef _split_tool_calls_section(content: str) -> tuple[str, str | None]:\n    match = _TOOL_CALLS_SECTION_RE.search(content)\n    if not match:\n        return content, None\n    tool_section = match.group(1) if match.group(1) is not None else match.group(2)\n    return content[: match.start()], tool_section\n\n\ndef _extract_tool_name(tool_id: str) -> str:\n    if not tool_id:\n        return \"\"\n    name_part = tool_id.split(\":\", 1)[0]\n    if \".\" in name_part:\n        _, name_part = name_part.split(\".\", 1)\n    return name_part\n\n\ndef _parse_tool_calls_section(\n    tool_section: str,\n) -> tuple[list[ToolCall], list[UnparsedToolCall]]:\n    tool_calls: list[ToolCall] = []\n    unparsed_tool_calls: list[UnparsedToolCall] = []\n\n    for match in _TOOL_CALL_RE.finditer(tool_section):\n        raw_text = match.group(0)\n        tool_id = match.group(1).strip()\n        args_str = match.group(2).strip()\n        func_name = _extract_tool_name(tool_id)\n\n        try:\n            json.loads(args_str)\n            tool_calls.append(\n                ToolCall(\n                    function=ToolCall.FunctionBody(name=func_name, arguments=args_str),\n                    id=tool_id if tool_id else None,\n                )\n            )\n        except json.JSONDecodeError as e:\n            unparsed_tool_calls.append(\n                UnparsedToolCall(raw_text=raw_text, error=f\"Invalid JSON: {e}\")\n            )\n\n    return tool_calls, unparsed_tool_calls\n\n\nclass KimiK2Renderer(Renderer):\n    \"\"\"\n    Format for moonshotai/Kimi-K2-Thinking:\n        <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>\n        <|im_user|>user<|im_middle|>What can you help me with?<|im_end|>\n        <|im_assistant|>assistant<|im_middle|><think>reasoning</think>I can help you with...<|im_end|>\n\n    Historical assistant messages use empty <think></think> blocks, while the assistant messages after the\n    last non-tool-call assistant message preserves reasoning_content in the thinking block.\n\n    Note: Per the HuggingFace chat template, the default system message is automatically\n    prepended if no system message is provided. 
This ensures train-eval consistency when\n    using HF's apply_chat_template for inference.\n    \"\"\"\n\n    supports_streaming = True\n\n    DEFAULT_SYSTEM_PROMPT = \"You are Kimi, an AI assistant created by Moonshot AI.\"\n\n    def __init__(self, tokenizer: Tokenizer, strip_thinking_from_history: bool = True):\n        super().__init__(tokenizer)\n        self.strip_thinking_from_history = strip_thinking_from_history\n\n    def _ensure_system_message(self, messages: list[Message]) -> list[Message]:\n        \"\"\"Ensure a default system message is present if none exists.\n\n        This matches the HuggingFace chat template behavior where a default system\n        message is automatically added when none is provided.\n\n        The default system message is inserted at the appropriate position:\n        - If messages is empty: adds default system message\n        - If starting with tool_declare: inserts default system after tool_declare (if no system message follows)\n        - Otherwise: prepends default system message before first message (if first message isn't system)\n        \"\"\"\n        if not messages:\n            default_system = Message(role=\"system\", content=self.DEFAULT_SYSTEM_PROMPT)\n            return [default_system]\n\n        # Accept both system and tool_declare as valid starting messages\n        first_role = messages[0][\"role\"]\n        if first_role == \"tool_declare\":\n            # Check if a system message already exists after tool_declare\n            if len(messages) >= 2 and messages[1][\"role\"] == \"system\":\n                return messages\n            # No system message, insert default after tool_declare\n            default_system = Message(role=\"system\", content=self.DEFAULT_SYSTEM_PROMPT)\n            return [messages[0], default_system] + list(messages[1:])\n        elif first_role != \"system\":\n            default_system = Message(role=\"system\", content=self.DEFAULT_SYSTEM_PROMPT)\n            return [default_system] + list(messages)\n\n        return messages\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        \"\"\"\n        Render a message. For assistant messages, ctx.is_last controls whether thinking is preserved\n        (True) or stripped to empty <think></think> (False).\n        \"\"\"\n        role = message[\"role\"]\n\n        # Build role token based on role type\n        if role == \"user\":\n            header_str = f\"<|im_user|>{role}<|im_middle|>\"\n        elif role == \"assistant\":\n            header_str = f\"<|im_assistant|>{role}<|im_middle|>\"\n        elif role == \"system\":\n            header_str = f\"<|im_system|>{role}<|im_middle|>\"\n        elif role == \"tool_declare\":\n            # Tool declaration uses system token but with \"tool_declare\" as display name\n            header_str = f\"<|im_system|>{role}<|im_middle|>\"\n        elif role == \"tool\":\n            # HF template uses message.name if present, otherwise role\n            role_name = message.get(\"name\")\n            if not role_name:\n                warnings.warn(\n                    \"Tool message missing 'name' field. Using 'tool' as fallback. 
\"\n                    \"Consider setting 'name' to match the tool function name for better context.\",\n                    UserWarning,\n                    stacklevel=3,\n                )\n                role_name = role\n            header_str = f\"<|im_system|>{role_name}<|im_middle|>\"\n\n            # Tool responses have special formatting - need tool_call_id to correlate with the call\n            tool_call_id = message.get(\"tool_call_id\", \"\")\n            if not tool_call_id:\n                warnings.warn(\n                    \"Tool message missing 'tool_call_id' field. KimiK2Renderer requires 'tool_call_id' \"\n                    \"to render tool results correctly. The value should match ToolCall.id from the \"\n                    \"assistant's tool_calls.\",\n                    UserWarning,\n                    stacklevel=3,\n                )\n            header_str += f\"## Return of {tool_call_id}\\n\"\n        else:\n            # Unknown roles default to system-style formatting\n            header_str = f\"<|im_system|>{role}<|im_middle|>\"\n\n        # Build output content\n        content = message[\"content\"]\n        output: list[tinker.ModelInputChunk] = []\n        if role == \"assistant\":\n            output_str = \"\"\n            # Extract thinking and text from content list\n            parts = ensure_list(content)\n            thinking_content = \"\".join(p[\"thinking\"] for p in parts if p[\"type\"] == \"thinking\")\n            text_content = \"\".join(p[\"text\"] for p in parts if p[\"type\"] == \"text\")\n\n            # Preserve thinking for the last assistant message, or for all messages\n            # when strip_thinking_from_history is False.\n            if (ctx.is_last or not self.strip_thinking_from_history) and thinking_content:\n                output_str = f\"<think>{thinking_content}</think>\"\n            else:\n                output_str = \"<think></think>\"\n            output_str += text_content\n\n            # Handle tool calls\n            if \"tool_calls\" in message and message[\"tool_calls\"]:  # noqa: RUF019\n                output_str += \"<|tool_calls_section_begin|>\"\n                for idx, tool_call in enumerate(message[\"tool_calls\"]):\n                    tool_id = tool_call.id\n                    if not tool_id:\n                        tool_id = f\"functions.{tool_call.function.name}:{idx}\"\n                    args = tool_call.function.arguments\n                    output_str += f\"<|tool_call_begin|>{tool_id}<|tool_call_argument_begin|>{args}<|tool_call_end|>\"\n                output_str += \"<|tool_calls_section_end|>\"\n            output_str += \"<|im_end|>\"\n            output.append(tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(output_str)))\n        elif isinstance(content, str) or (len(content) == 1 and content[0][\"type\"] == \"text\"):\n            # Single-part/text content\n            output_str = ensure_text(content) + \"<|im_end|>\"\n            output.append(tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(output_str)))\n        else:\n            # Mult-part content (e.g. 
text+image(s))\n            assert isinstance(content, list), f\"Expected list of content parts, got {type(content)}\"\n            output = self._encode_multipart_content(\n                content + [TextPart(type=\"text\", text=\"<|im_end|>\")]\n            )\n\n        header = tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(header_str))\n\n        return RenderedMessage(header=header, output=output)\n\n    def _encode_multipart_content(self, content: list[ContentPart]) -> list[tinker.ModelInputChunk]:\n        raise NotImplementedError(\n            \"Multipart/Image content encoding is not supported for Kimi K2 renderer\"\n        )\n\n    def build_generation_prompt(\n        self, messages: list[Message], role: Role = \"assistant\", prefill: str | None = None\n    ) -> tinker.ModelInput:\n        messages = self._ensure_system_message(messages)\n        chunks: list[tinker.types.ModelInputChunk] = []\n\n        # Find last assistant message without tool calls (matches hf template behavior).\n        last_assistant_idx = -1\n        for idx in range(len(messages) - 1, -1, -1):\n            if messages[idx][\"role\"] == \"assistant\" and not messages[idx].get(\"tool_calls\"):\n                last_assistant_idx = idx\n                break\n\n        for idx, message in enumerate(messages):\n            is_assistant = message[\"role\"] == \"assistant\"\n            is_last_assistant = is_assistant and (\n                last_assistant_idx == -1 or idx > last_assistant_idx\n            )\n\n            # We cannot simply set is_last=False since we might be generating a new assistant message following a tool response,\n            # and we need to preserve the thinking that leads to the tool call.\n            ctx = RenderContext(\n                idx=idx,\n                is_last=is_last_assistant,\n                prev_message=messages[idx - 1] if idx > 0 else None,\n            )\n            rendered_message = self.render_message(message, ctx)\n            header_chunk = rendered_message.header\n            output_chunks = rendered_message.output\n            if header_chunk:\n                chunks.append(header_chunk)\n            chunks.extend([x for x in output_chunks if x])\n\n        # Add generation prompt for new assistant message\n        gen_prompt = f\"<|im_assistant|>{role}<|im_middle|>\"\n        chunks.append(tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(gen_prompt)))\n        if prefill:\n            chunks.append(tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(prefill)))\n        return tinker.ModelInput(chunks=chunks)\n\n    def build_supervised_examples(\n        self,\n        messages: list[Message],\n        train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_TURN,\n    ) -> list[tuple[tinker.ModelInput, torch.Tensor]]:\n        \"\"\"\n        Build tokens and per-token weights for supervised fine-tuning. 
Since Kimi K2 renderer does not satisfy the extension property, this method is provided to return multiple examples in case we want to train on multiple assistant messages, potentially across multiple turns of user-assistant conversation.\n        \"\"\"\n\n        if (\n            train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE\n            or train_on_what == TrainOnWhat.LAST_ASSISTANT_TURN\n        ):\n            return [self.build_supervised_example(messages, train_on_what=train_on_what)]\n\n        # split the messages into turns by user messages\n        user_message_idxs = [\n            idx for idx, message in enumerate(messages) if message[\"role\"] == \"user\"\n        ]\n\n        supervised_examples: list[tuple[tinker.ModelInput, torch.Tensor]] = []\n\n        if train_on_what != TrainOnWhat.ALL_ASSISTANT_MESSAGES:\n            warnings.warn(\n                \"WARNING: Using train_on_what=ALL_MESSAGES/ALL_TOKENS/ALL_USER_AND_SYSTEM_MESSAGES/CUSTOMIZED with a renderer that \"\n                \"does not satisfy the extension property (has_extension_property=False). \"\n                \"The behavior is we apply the same `train_on_what` to all turns. This may not be the desired behavior.\",\n                UserWarning,\n                stacklevel=3,\n            )\n\n        # We separate the turns by user messages. The first turn is the messages before the second user message.\n        for user_message_idx in [*user_message_idxs[1:], len(messages)]:\n            current_messages = messages[:user_message_idx]\n            if train_on_what == TrainOnWhat.ALL_ASSISTANT_MESSAGES:\n                supervised_examples.append(\n                    self.build_supervised_example(\n                        current_messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_TURN\n                    )\n                )\n            else:\n                supervised_examples.append(\n                    self.build_supervised_example(current_messages, train_on_what=train_on_what)\n                )\n\n        return supervised_examples\n\n    def build_supervised_example(\n        self,\n        messages: list[Message],\n        train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_MESSAGE,\n    ) -> tuple[tinker.ModelInput, torch.Tensor]:\n        \"\"\"\n        Override to properly handle thinking preservation for the last assistant message.\n        Also ensures default system message is prepended if none is present.\n        \"\"\"\n        messages = self._ensure_system_message(messages)\n\n        # Kimi K2 hf template preserves the thinking of the assistant messages after the last non-tool-call assistant message.\n        # We do the same in general. 
However, we intentionally skip the last message (which differs from HF template behavior) since for a complete conversation,\n        # we would want to preserve the thinking of the last round of conversation between the user and the assistant (which could include multiple assistant messages and tool calls).\n        # This is because the trajectory would then be taken for SFT without losing all the thinking content.\n        last_assistant_idx = -1\n        for idx in range(len(messages) - 2, -1, -1):\n            if messages[idx][\"role\"] == \"assistant\" and not messages[idx].get(\"tool_calls\"):\n                last_assistant_idx = idx\n                break\n\n        model_input_chunks_weights: list[tuple[tinker.types.ModelInputChunk, float]] = []\n\n        for idx, message in enumerate(messages):\n            if train_on_what == TrainOnWhat.CUSTOMIZED:\n                assert \"trainable\" in message, (\n                    \"When using CUSTOMIZED train_on_what, each message must have a trainable field\"\n                )\n            else:\n                assert \"trainable\" not in message, (\n                    \"When using non-CUSTOMIZED train_on_what, each message must not have a trainable field\"\n                )\n\n            is_assistant = message[\"role\"] == \"assistant\"\n            is_last_message = idx == len(messages) - 1\n            is_user_or_system = message[\"role\"] in [\"user\", \"system\"]\n\n            # For Kimi K2, preserve thinking only for the suffix after the last non-tool-call assistant.\n            # If no such assistant exists, the suffix is the entire message list.\n            # Preserve thinking only for assistants after the last non-tool-call assistant.\n            is_last_assistant_turn = is_assistant and (\n                last_assistant_idx == -1 or idx > last_assistant_idx\n            )\n\n            is_last_assistant = is_assistant and is_last_message\n            ctx = RenderContext(\n                idx=idx,\n                is_last=is_last_assistant_turn,\n                prev_message=messages[idx - 1] if idx > 0 else None,\n            )\n            rendered_message = self.render_message(message, ctx)\n\n            header_part = rendered_message.header\n            output_parts = rendered_message.output\n\n            header_weight = int(train_on_what == TrainOnWhat.ALL_TOKENS)\n            if header_part:\n                model_input_chunks_weights += [(header_part, header_weight)]\n\n            # We include all assistant messages in the last round of assistant-tool interactions as the last assistant message.\n            match train_on_what:\n                case TrainOnWhat.LAST_ASSISTANT_MESSAGE:\n                    output_has_weight = is_last_assistant\n                case TrainOnWhat.LAST_ASSISTANT_TURN:\n                    output_has_weight = is_last_assistant_turn\n                case TrainOnWhat.ALL_ASSISTANT_MESSAGES:\n                    output_has_weight = is_assistant\n                case TrainOnWhat.ALL_MESSAGES:\n                    output_has_weight = True\n                case TrainOnWhat.ALL_TOKENS:\n                    output_has_weight = True\n                case TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES:\n                    output_has_weight = is_user_or_system\n                case TrainOnWhat.CUSTOMIZED:\n                    output_has_weight = message.get(\"trainable\", False)\n                case _:\n                    raise RendererError(f\"Unknown train_on_what: {train_on_what}\")\n\n    
        model_input_chunks_weights += [\n                (output_part, int(output_has_weight)) for output_part in output_parts if output_part\n            ]\n\n        weights_data = [w for chunk, w in model_input_chunks_weights for _ in range(chunk.length)]\n        weights_tensor = torch.tensor(weights_data)\n\n        model_input_chunks = [chunk for chunk, _ in model_input_chunks_weights]\n        return tinker.ModelInput(chunks=model_input_chunks), weights_tensor\n\n    @property\n    def _end_message_token(self) -> int:\n        tokens = self.tokenizer.encode(\"<|im_end|>\")\n        assert len(tokens) == 1, f\"Expected single token for <|im_end|>, got {len(tokens)}\"\n        return tokens[0]\n\n    def get_stop_sequences(self) -> list[int]:\n        return [self._end_message_token]\n\n    def parse_response(self, response: list[int]) -> tuple[Message, bool]:\n        response = self._normalize_response_tokens(response)\n        assistant_message, parse_success = parse_response_for_stop_token(\n            response, self.tokenizer, self._end_message_token\n        )\n        if not parse_success:\n            return assistant_message, False\n\n        content = assistant_message[\"content\"]\n        assert isinstance(content, str)\n\n        # Handle tool calls if present\n        text_content, tool_section = _split_tool_calls_section(content)\n        if tool_section is not None:\n            tool_calls, unparsed_tool_calls = _parse_tool_calls_section(tool_section)\n            if tool_calls:\n                assistant_message[\"tool_calls\"] = tool_calls\n            if unparsed_tool_calls:\n                assistant_message[\"unparsed_tool_calls\"] = unparsed_tool_calls\n\n        content_parts = parse_think_blocks(text_content)\n        assistant_message[\"content\"] = content_parts if content_parts is not None else text_content\n\n        return assistant_message, True\n\n    def _parse_response_for_streaming(self, response: list[int]) -> tuple[Message, bool]:\n        \"\"\"Parse response for streaming, always applying full content parsing.\n\n        Unlike parse_response which short-circuits on missing stop token,\n        this always parses think blocks and tool calls from the content.\n        This matches the original KimiK2StreamingParser.finish() behavior\n        where content parsing was applied regardless of stop token presence.\n        \"\"\"\n        message, parse_success = parse_response_for_stop_token(\n            response, self.tokenizer, self._end_message_token\n        )\n\n        content = message.get(\"content\", \"\")\n        if isinstance(content, str):\n            text_content, tool_section = _split_tool_calls_section(content)\n            if tool_section is not None:\n                tool_calls, unparsed_tool_calls = _parse_tool_calls_section(tool_section)\n                if tool_calls:\n                    message[\"tool_calls\"] = tool_calls\n                if unparsed_tool_calls:\n                    message[\"unparsed_tool_calls\"] = unparsed_tool_calls\n\n            content_parts = parse_think_blocks(text_content)\n            message[\"content\"] = content_parts if content_parts is not None else text_content\n\n        return message, parse_success\n\n    def to_openai_message(self, message: Message) -> dict:\n        \"\"\"Convert a Message to OpenAI API format with reasoning_content for thinking.\n\n        Kimi K2's HF template explicitly expects reasoning_content as a separate field.\n        \"\"\"\n        result: dict = {\"role\": 
message[\"role\"]}\n\n        content = message[\"content\"]\n        if isinstance(content, str):\n            result[\"content\"] = content\n        else:\n            # Extract thinking into reasoning_content, keep text in content\n            thinking_parts = []\n            text_parts = []\n            for p in content:\n                if p[\"type\"] == \"thinking\":\n                    thinking_parts.append(p[\"thinking\"])\n                elif p[\"type\"] == \"text\":\n                    text_parts.append(p[\"text\"])\n\n            result[\"content\"] = \"\".join(text_parts)\n            if thinking_parts:\n                result[\"reasoning_content\"] = \"\".join(thinking_parts)\n\n        # Handle tool_calls\n        if \"tool_calls\" in message and message[\"tool_calls\"]:  # noqa: RUF019\n            result[\"tool_calls\"] = [\n                {\n                    \"type\": \"function\",\n                    \"id\": tc.id,\n                    \"function\": {\n                        \"name\": tc.function.name,\n                        \"arguments\": tc.function.arguments,\n                    },\n                }\n                for tc in message[\"tool_calls\"]\n            ]\n\n        # Handle tool response fields\n        if message[\"role\"] == \"tool\":\n            if \"tool_call_id\" in message:\n                result[\"tool_call_id\"] = message[\"tool_call_id\"]\n            if \"name\" in message:\n                result[\"name\"] = message[\"name\"]\n\n        return result\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> list[Message]:\n        \"\"\"Create system messages with Kimi K2 tool specifications.\n\n        Per the HuggingFace chat template, Kimi K2 places the tool_declare message\n        BEFORE the regular system message. The tool_declare payload expects the\n        OpenAI-style tool schema ({\"type\":\"function\",\"function\":{...}}).\n        If no system_prompt is provided, uses the default system prompt to match\n        HuggingFace chat template behavior.\n\n        Reference: https://huggingface.co/moonshotai/Kimi-K2-Thinking/blob/main/chat_template.jinja\n        \"\"\"\n        messages: list[Message] = []\n\n        # Tool declaration message comes first (per HF chat template)\n        if tools:\n            tools_payload = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n            # Use sort_keys=True since Kimi K2 sorts keys alphabetically with its own custom apply_chat_template function\n            tools_json = json.dumps(tools_payload, separators=(\",\", \":\"), sort_keys=True)\n            messages.append(Message(role=\"tool_declare\", content=tools_json))\n\n        # Regular system message second (use default if none provided)\n        actual_system_prompt = system_prompt if system_prompt else self.DEFAULT_SYSTEM_PROMPT\n        messages.append(Message(role=\"system\", content=actual_system_prompt))\n\n        return messages\n"
  },
  {
    "path": "tinker_cookbook/renderers/kimi_k25.py",
    "content": "\"\"\"Renderer for Moonshot AI's Kimi K2.5 models.\"\"\"\n\nfrom typing import cast\n\nimport tinker\n\nfrom tinker_cookbook.exceptions import RendererError\nfrom tinker_cookbook.image_processing_utils import ImageProcessor\nfrom tinker_cookbook.renderers.base import (\n    ContentPart,\n    ImageProcessorProtocol,\n    Message,\n    Role,\n    ToolSpec,\n    image_to_chunk,\n)\nfrom tinker_cookbook.renderers.kimi_k2 import KimiK2Renderer\nfrom tinker_cookbook.renderers.kimi_k2_5_tool_declaration_ts import encode_tools_to_typescript_style\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n\nclass KimiK25Renderer(KimiK2Renderer):\n    \"\"\"\n    Renderer for Kimi K2.5 with thinking enabled (default).\n\n    Key differences from KimiK2Renderer:\n    1. Generation prompt prefill: Appends `<think>` (open tag) to enable thinking mode\n    2. Tool declarations: Uses TypeScript-style format instead of JSON\n\n    Format:\n        <|im_system|>system<|im_middle|>You are Kimi...<|im_end|>\n        <|im_user|>user<|im_middle|>Hello<|im_end|>\n        <|im_assistant|>assistant<|im_middle|><think>\n\n    Historical assistant messages use empty <think></think> blocks (inherited from K2),\n    while the generation prompt adds an open <think> tag to enable thinking.\n    \"\"\"\n\n    image_processor: ImageProcessor | None\n    _think_open_token: int\n    _think_close_token: int\n\n    def __init__(\n        self,\n        tokenizer: Tokenizer,\n        image_processor: ImageProcessor | None = None,\n        strip_thinking_from_history: bool = True,\n    ):\n        super().__init__(tokenizer, strip_thinking_from_history=strip_thinking_from_history)\n        self.image_processor = image_processor\n        (self._think_open_token,) = self.tokenizer.encode(\"<think>\", add_special_tokens=False)\n        (self._think_close_token,) = self.tokenizer.encode(\"</think>\", add_special_tokens=False)\n\n    def _encode_multipart_content(self, content: list[ContentPart]) -> list[tinker.ModelInputChunk]:\n        chunks = []\n        for part in content:\n            if part[\"type\"] == \"text\":\n                chunks.append(\n                    tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(part[\"text\"]))\n                )\n            elif part[\"type\"] == \"image\":\n                assert self.image_processor is not None, (\n                    \"KimiK25Renderer must be initialized with an image processor in order to support image content parts\"\n                )\n                chunks.append(\n                    tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(self._image_prefix))\n                )\n                chunks.append(\n                    image_to_chunk(\n                        part[\"image\"], cast(ImageProcessorProtocol, self.image_processor)\n                    )\n                )\n                chunks.append(\n                    tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(self._image_suffix))\n                )\n            else:\n                raise RendererError(f\"Unsupported content type: {part['type']}\")\n        return chunks\n\n    @property\n    def _image_prefix(self) -> str:\n        return \"<|media_begin|>image<|media_content|>\"\n\n    @property\n    def _image_suffix(self) -> str:\n        return \"<|media_end|>\\n\"\n\n    def build_generation_prompt(\n        self, messages: list[Message], role: Role = \"assistant\", prefill: str | None = None\n    ) -> tinker.ModelInput:\n        \"\"\"Build 
generation prompt with <think> prefill for thinking mode.\"\"\"\n        # If no prefill specified, use <think> to enable thinking\n        if prefill is None:\n            prefill = \"<think>\"\n        return super().build_generation_prompt(messages, role=role, prefill=prefill)\n\n    def _normalize_response_tokens(self, response: list[int]) -> list[int]:\n        \"\"\"Restore the synthetic <think> prefill before parsing sampled tokens.\"\"\"\n        if (\n            response\n            and response[0] != self._think_open_token\n            and self._think_close_token in response\n        ):\n            return [self._think_open_token, *response]\n        return response\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> list[Message]:\n        \"\"\"Create system messages with TypeScript-style tool specifications.\n\n        Per the HuggingFace chat template, Kimi K2.5 uses TypeScript-style tool\n        declarations instead of JSON format. The tool_declare message comes BEFORE\n        the regular system message.\n\n        Reference: kimi-k2.5-hf-tokenizer/chat_template.jinja\n        \"\"\"\n        messages: list[Message] = []\n\n        # Tool declaration message comes first (per HF chat template)\n        if tools:\n            tools_payload = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n            tools_ts_str = encode_tools_to_typescript_style(tools_payload)\n            messages.append(Message(role=\"tool_declare\", content=tools_ts_str))\n\n        # Regular system message second (use default if none provided)\n        actual_system_prompt = system_prompt if system_prompt else self.DEFAULT_SYSTEM_PROMPT\n        messages.append(Message(role=\"system\", content=actual_system_prompt))\n\n        return messages\n\n\nclass KimiK25DisableThinkingRenderer(KimiK25Renderer):\n    \"\"\"\n    Renderer for Kimi K2.5 with thinking disabled.\n\n    Uses `<think></think>` prefill instead of `<think>` to disable thinking mode.\n\n    Format:\n        <|im_system|>system<|im_middle|>You are Kimi...<|im_end|>\n        <|im_user|>user<|im_middle|>Hello<|im_end|>\n        <|im_assistant|>assistant<|im_middle|><think></think>\n    \"\"\"\n\n    def build_generation_prompt(\n        self, messages: list[Message], role: Role = \"assistant\", prefill: str | None = None\n    ) -> tinker.ModelInput:\n        \"\"\"Build generation prompt with <think></think> prefill to disable thinking.\"\"\"\n        # If no prefill specified, use <think></think> to disable thinking\n        if prefill is None:\n            prefill = \"<think></think>\"\n        return super(KimiK25Renderer, self).build_generation_prompt(\n            messages, role=role, prefill=prefill\n        )\n"
  },
  {
    "path": "tinker_cookbook/renderers/kimi_k25_test.py",
    "content": "\"\"\"\nTests for Kimi K2.5 renderer.\n\nTests verify that the KimiK25Renderer produces correct output:\n1. Generation prompt includes `<think>` prefill (thinking enabled)\n2. Disable-thinking variant uses `<think></think>` prefill\n3. TypeScript-style tool declarations\n4. HF template compatibility for both build_generation_prompt and build_supervised_example\n\"\"\"\n\nfrom typing import cast\n\nimport pytest\nimport tinker\nfrom PIL import Image\n\nfrom tinker_cookbook.image_processing_utils import get_image_processor\nfrom tinker_cookbook.renderers import (\n    Message,\n    StreamingTextDelta,\n    StreamingThinkingDelta,\n    TextPart,\n    ThinkingPart,\n    ToolCall,\n    ToolSpec,\n    get_renderer,\n)\nfrom tinker_cookbook.renderers.kimi_k2_5_tool_declaration_ts import encode_tools_to_typescript_style\nfrom tinker_cookbook.renderers.testing_utils import extract_token_ids\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nKIMI_K25_MODEL = \"moonshotai/Kimi-K2.5\"\n\n\n# =============================================================================\n# Test Fixtures\n# =============================================================================\n\n\n@pytest.fixture(scope=\"module\")\ndef kimi_tokenizer():\n    \"\"\"Get the Kimi K2.5 tokenizer (cached per module).\"\"\"\n    try:\n        return get_tokenizer(KIMI_K25_MODEL)\n    except ModuleNotFoundError as e:\n        if \"Kimi-K2\" in str(e):\n            pytest.skip(f\"K2.5 tokenizer has HF module import bug: {e}\")\n        raise\n\n\n@pytest.fixture(scope=\"module\")\ndef kimi_renderer(kimi_tokenizer):\n    \"\"\"Get the Kimi K2.5 renderer (cached per module).\"\"\"\n    return get_renderer(\"kimi_k25\", kimi_tokenizer)\n\n\n@pytest.fixture(scope=\"module\")\ndef kimi_renderer_disable_thinking(kimi_tokenizer):\n    \"\"\"Get the Kimi K2.5 disable-thinking renderer (cached per module).\"\"\"\n    return get_renderer(\"kimi_k25_disable_thinking\", kimi_tokenizer)\n\n\n@pytest.fixture(scope=\"module\")\ndef hf_generation_prompt_length(kimi_tokenizer):\n    \"\"\"Calculate the number of tokens in the HF generation prompt (cached per module).\n\n    Uses a dummy conversation to find the difference between with/without generation prompt.\n    This is constant regardless of conversation content.\n    \"\"\"\n    dummy_msgs = [{\"role\": \"user\", \"content\": \"hi\"}]\n    tokens_with = extract_token_ids(\n        kimi_tokenizer.apply_chat_template(\n            dummy_msgs, add_generation_prompt=True, tokenize=True, thinking=True\n        )\n    )\n    tokens_without = extract_token_ids(\n        kimi_tokenizer.apply_chat_template(\n            dummy_msgs, add_generation_prompt=False, tokenize=True, thinking=True\n        )\n    )\n    return len(tokens_with) - len(tokens_without)\n\n\ndef get_hf_tokens(\n    tokenizer, hf_messages, gen_prompt_length: int, tools=None, for_generation: bool = True\n) -> list[int]:\n    \"\"\"Get HF tokens for generation or supervised mode.\n\n    For supervised mode, slices off the generation prompt tokens.\n    \"\"\"\n    tokens = extract_token_ids(\n        tokenizer.apply_chat_template(\n            hf_messages,\n            tools=tools,\n            add_generation_prompt=True,\n            tokenize=True,\n            thinking=True,\n        )\n    )\n\n    if for_generation:\n        return tokens\n    return tokens[:-gen_prompt_length] if gen_prompt_length else tokens\n\n\n# =============================================================================\n# Helpers\n# 
=============================================================================\n\n\ndef get_tool_spec() -> ToolSpec:\n    \"\"\"Sample tool specification for testing.\"\"\"\n    return ToolSpec(\n        name=\"get_weather\",\n        description=\"Get the current weather for a location\",\n        parameters={\n            \"type\": \"object\",\n            \"properties\": {\n                \"location\": {\n                    \"type\": \"string\",\n                    \"description\": \"The city and state, e.g. San Francisco, CA\",\n                },\n                \"unit\": {\n                    \"type\": \"string\",\n                    \"enum\": [\"celsius\", \"fahrenheit\"],\n                    \"description\": \"Temperature unit\",\n                },\n            },\n            \"required\": [\"location\"],\n        },\n    )\n\n\n# =============================================================================\n# Test Conversations\n# =============================================================================\n\n\ndef get_basic_conversation_for_generation() -> list[Message]:\n    \"\"\"3-turn conversation ending with user message (for generation).\"\"\"\n    return [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        {\"role\": \"assistant\", \"content\": \"I'm fine, thank you!\"},\n        {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n    ]\n\n\ndef get_basic_conversation_for_supervised() -> list[Message]:\n    \"\"\"2-turn conversation ending with assistant (for supervised).\"\"\"\n    return [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        {\"role\": \"assistant\", \"content\": \"I'm fine, thank you!\"},\n    ]\n\n\ndef get_tool_call_conversation_for_generation() -> tuple[list[Message], list[ToolSpec]]:\n    \"\"\"Conversation with tool call, ending ready for generation.\"\"\"\n    tools = [get_tool_spec()]\n    tool_call = ToolCall(\n        id=\"functions.get_weather:0\",\n        function=ToolCall.FunctionBody(\n            name=\"get_weather\",\n            arguments='{\"location\": \"New York, NY\"}',\n        ),\n    )\n    messages: list[Message] = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\"type\": \"thinking\", \"thinking\": \"I need to check the weather in New York City.\"},\n                {\"type\": \"text\", \"text\": \"\"},\n            ],\n            \"tool_calls\": [tool_call],\n        },\n        {\n            \"role\": \"tool\",\n            \"name\": \"get_weather\",\n            \"tool_call_id\": \"functions.get_weather:0\",\n            \"content\": '{\"temperature\": 72, \"condition\": \"sunny\"}',\n        },\n    ]\n    return messages, tools\n\n\ndef get_tool_call_conversation_for_supervised() -> tuple[list[Message], list[ToolSpec]]:\n    \"\"\"Complete tool call conversation with final assistant response (for supervised).\"\"\"\n    tools = [get_tool_spec()]\n    tool_call = ToolCall(\n        id=\"functions.get_weather:0\",\n        function=ToolCall.FunctionBody(\n            name=\"get_weather\",\n            arguments='{\"location\": \"New York, NY\"}',\n        ),\n    )\n    messages: 
list[Message] = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\"type\": \"thinking\", \"thinking\": \"I need to check the weather in New York City.\"},\n                {\"type\": \"text\", \"text\": \"\"},\n            ],\n            \"tool_calls\": [tool_call],\n        },\n        {\n            \"role\": \"tool\",\n            \"name\": \"get_weather\",\n            \"tool_call_id\": \"functions.get_weather:0\",\n            \"content\": '{\"temperature\": 72, \"condition\": \"sunny\"}',\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\"type\": \"thinking\", \"thinking\": \"The weather data shows 72F and sunny.\"},\n                {\"type\": \"text\", \"text\": \"The weather in NYC is 72°F and sunny.\"},\n            ],\n        },\n    ]\n    return messages, tools\n\n\ndef get_multi_tool_call_conversation_for_generation() -> tuple[list[Message], list[ToolSpec]]:\n    \"\"\"Conversation with multiple tool calls in one message.\"\"\"\n    tools = [get_tool_spec()]\n    tool_calls = [\n        ToolCall(\n            id=\"functions.get_weather:0\",\n            function=ToolCall.FunctionBody(\n                name=\"get_weather\",\n                arguments='{\"location\": \"New York, NY\"}',\n            ),\n        ),\n        ToolCall(\n            id=\"functions.get_weather:1\",\n            function=ToolCall.FunctionBody(\n                name=\"get_weather\",\n                arguments='{\"location\": \"Los Angeles, CA\"}',\n            ),\n        ),\n    ]\n    messages: list[Message] = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"What's the weather in NYC and LA?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\"type\": \"thinking\", \"thinking\": \"I'll check the weather in both cities.\"},\n                {\"type\": \"text\", \"text\": \"\"},\n            ],\n            \"tool_calls\": tool_calls,\n        },\n        {\n            \"role\": \"tool\",\n            \"name\": \"get_weather\",\n            \"tool_call_id\": \"functions.get_weather:0\",\n            \"content\": '{\"temperature\": 72, \"condition\": \"sunny\"}',\n        },\n        {\n            \"role\": \"tool\",\n            \"name\": \"get_weather\",\n            \"tool_call_id\": \"functions.get_weather:1\",\n            \"content\": '{\"temperature\": 85, \"condition\": \"clear\"}',\n        },\n    ]\n    return messages, tools\n\n\ndef get_multi_step_tool_conversation_for_generation() -> tuple[list[Message], list[ToolSpec]]:\n    \"\"\"Multi-step tool calling: multiple rounds of tool calls.\"\"\"\n    tools = [get_tool_spec()]\n    messages: list[Message] = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"Compare the weather in NYC and LA.\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\"type\": \"thinking\", \"thinking\": \"Let me check NYC weather first.\"},\n                {\"type\": \"text\", \"text\": \"\"},\n            ],\n            \"tool_calls\": [\n                ToolCall(\n                    id=\"functions.get_weather:0\",\n                    
function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"New York, NY\"}',\n                    ),\n                ),\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"name\": \"get_weather\",\n            \"tool_call_id\": \"functions.get_weather:0\",\n            \"content\": '{\"temperature\": 72, \"condition\": \"sunny\"}',\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\"type\": \"thinking\", \"thinking\": \"Now let me check LA weather.\"},\n                {\"type\": \"text\", \"text\": \"\"},\n            ],\n            \"tool_calls\": [\n                ToolCall(\n                    id=\"functions.get_weather:1\",\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"Los Angeles, CA\"}',\n                    ),\n                ),\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"name\": \"get_weather\",\n            \"tool_call_id\": \"functions.get_weather:1\",\n            \"content\": '{\"temperature\": 85, \"condition\": \"clear\"}',\n        },\n    ]\n    return messages, tools\n\n\n# =============================================================================\n# TypeScript Tool Declaration Tests\n# =============================================================================\n\n\ndef test_typescript_tool_declaration_basic():\n    \"\"\"Test basic TypeScript tool declaration generation.\"\"\"\n    tools = [{\"type\": \"function\", \"function\": get_tool_spec()}]\n    ts_str = encode_tools_to_typescript_style(tools)\n\n    assert \"# Tools\" in ts_str\n    assert \"## functions\" in ts_str\n    assert \"namespace functions {\" in ts_str\n    assert \"get_weather\" in ts_str\n    assert \"type get_weather = (_:\" in ts_str\n    assert \"location\" in ts_str\n    assert \"string\" in ts_str\n\n\ndef test_typescript_tool_declaration_with_enum():\n    \"\"\"Test TypeScript declaration includes enum values.\"\"\"\n    tools = [{\"type\": \"function\", \"function\": get_tool_spec()}]\n    ts_str = encode_tools_to_typescript_style(tools)\n\n    assert '\"celsius\"' in ts_str or \"'celsius'\" in ts_str\n    assert '\"fahrenheit\"' in ts_str or \"'fahrenheit'\" in ts_str\n\n\ndef test_typescript_tool_declaration_description():\n    \"\"\"Test TypeScript declaration includes descriptions as comments.\"\"\"\n    tools = [{\"type\": \"function\", \"function\": get_tool_spec()}]\n    ts_str = encode_tools_to_typescript_style(tools)\n\n    assert \"// Get the current weather\" in ts_str\n\n\ndef test_typescript_tool_declaration_empty():\n    \"\"\"Test TypeScript declaration with empty tools list.\"\"\"\n    ts_str = encode_tools_to_typescript_style([])\n    assert ts_str == \"\"\n\n\ndef test_typescript_tool_declaration_multiple_tools():\n    \"\"\"Test TypeScript declaration with multiple tools.\"\"\"\n    tools = [\n        {\n            \"type\": \"function\",\n            \"function\": ToolSpec(\n                name=\"get_weather\",\n                description=\"Get the current weather for a location\",\n                parameters={\n                    \"type\": \"object\",\n                    \"properties\": {\"location\": {\"type\": \"string\"}},\n                    \"required\": [\"location\"],\n                },\n            ),\n        },\n        {\n            
\"type\": \"function\",\n            \"function\": ToolSpec(\n                name=\"search_web\",\n                description=\"Search the web for information\",\n                parameters={\n                    \"type\": \"object\",\n                    \"properties\": {\"query\": {\"type\": \"string\"}},\n                    \"required\": [\"query\"],\n                },\n            ),\n        },\n    ]\n    ts_str = encode_tools_to_typescript_style(tools)\n\n    assert \"type get_weather = (_:\" in ts_str\n    assert \"type search_web = (_:\" in ts_str\n    assert \"// Get the current weather\" in ts_str\n    assert \"// Search the web\" in ts_str\n\n\n# =============================================================================\n# Generation Prompt Prefill Tests (specific to generation)\n# =============================================================================\n\n\ndef test_kimi_k25_generation_prompt_has_think_prefill(kimi_tokenizer, kimi_renderer):\n    \"\"\"Test that KimiK25Renderer adds <think> prefill for generation.\"\"\"\n    messages = get_basic_conversation_for_generation()\n    gen_prompt = kimi_renderer.build_generation_prompt(messages)\n    decoded = kimi_tokenizer.decode(gen_prompt.to_ints())\n\n    assert decoded.endswith(\"<|im_assistant|>assistant<|im_middle|><think>\")\n\n\ndef test_kimi_k25_disable_thinking_generation_prompt(\n    kimi_tokenizer, kimi_renderer_disable_thinking\n):\n    \"\"\"Test that KimiK25DisableThinkingRenderer adds <think></think> prefill.\"\"\"\n    messages = get_basic_conversation_for_generation()\n    gen_prompt = kimi_renderer_disable_thinking.build_generation_prompt(messages)\n    decoded = kimi_tokenizer.decode(gen_prompt.to_ints())\n\n    assert decoded.endswith(\"<|im_assistant|>assistant<|im_middle|><think></think>\")\n\n\ndef test_kimi_k25_custom_prefill_overrides_default(kimi_tokenizer, kimi_renderer):\n    \"\"\"Test that custom prefill overrides the default <think> prefill.\"\"\"\n    messages = get_basic_conversation_for_generation()\n    custom_prefill = \"Custom response: \"\n    gen_prompt = kimi_renderer.build_generation_prompt(messages, prefill=custom_prefill)\n    decoded = kimi_tokenizer.decode(gen_prompt.to_ints())\n\n    assert decoded.endswith(custom_prefill)\n    assert not decoded.endswith(\"<think>\")\n\n\n# =============================================================================\n# HF Template Compatibility Tests - Parametrized for generation and supervised\n# =============================================================================\n\n\ndef test_kimi_k25_basic_conversation_matches_hf(\n    kimi_tokenizer, kimi_renderer, hf_generation_prompt_length\n):\n    \"\"\"Test basic conversation generation matches HF template.\"\"\"\n    messages = get_basic_conversation_for_generation()\n    cookbook_tokens = kimi_renderer.build_generation_prompt(messages).to_ints()\n\n    hf_messages = [kimi_renderer.to_openai_message(m) for m in messages]\n    hf_tokens = get_hf_tokens(\n        kimi_tokenizer, hf_messages, hf_generation_prompt_length, for_generation=True\n    )\n\n    assert cookbook_tokens == hf_tokens, (\n        f\"Cookbook string: {kimi_tokenizer.decode(cookbook_tokens)}\\n\"\n        f\"HF string: {kimi_tokenizer.decode(hf_tokens)}\"\n    )\n\n\ndef test_kimi_k25_tool_call_conversation_matches_hf(\n    kimi_tokenizer, kimi_renderer, hf_generation_prompt_length\n):\n    \"\"\"Test tool call conversation generation matches HF template.\"\"\"\n    messages, tools = 
get_tool_call_conversation_for_generation()\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n\n    prefix_messages = kimi_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=\"You are a helpful assistant.\"\n    )\n    prefix_messages = [m for m in prefix_messages if m[\"role\"] == \"tool_declare\"]\n    full_messages = prefix_messages + messages\n\n    cookbook_tokens = kimi_renderer.build_generation_prompt(full_messages).to_ints()\n\n    hf_messages = [kimi_renderer.to_openai_message(m) for m in messages]\n    hf_tokens = get_hf_tokens(\n        kimi_tokenizer,\n        hf_messages,\n        hf_generation_prompt_length,\n        tools=openai_tools,\n        for_generation=True,\n    )\n\n    assert cookbook_tokens == hf_tokens, (\n        f\"Cookbook string: {kimi_tokenizer.decode(cookbook_tokens)}\\n\"\n        f\"HF string: {kimi_tokenizer.decode(hf_tokens)}\"\n    )\n\n\ndef test_kimi_k25_multi_tool_calls_matches_hf(\n    kimi_tokenizer, kimi_renderer, hf_generation_prompt_length\n):\n    \"\"\"Test multiple tool calls in one message matches HF template.\"\"\"\n    messages, tools = get_multi_tool_call_conversation_for_generation()\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n\n    prefix_messages = kimi_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=\"You are a helpful assistant.\"\n    )\n    prefix_messages = [m for m in prefix_messages if m[\"role\"] == \"tool_declare\"]\n    full_messages = prefix_messages + messages\n\n    cookbook_tokens = kimi_renderer.build_generation_prompt(full_messages).to_ints()\n\n    hf_messages = [kimi_renderer.to_openai_message(m) for m in messages]\n    hf_tokens = get_hf_tokens(\n        kimi_tokenizer,\n        hf_messages,\n        hf_generation_prompt_length,\n        tools=openai_tools,\n        for_generation=True,\n    )\n\n    assert cookbook_tokens == hf_tokens, (\n        f\"Cookbook string: {kimi_tokenizer.decode(cookbook_tokens)}\\n\"\n        f\"HF string: {kimi_tokenizer.decode(hf_tokens)}\"\n    )\n\n\ndef test_kimi_k25_multi_step_tool_calls_matches_hf(\n    kimi_tokenizer, kimi_renderer, hf_generation_prompt_length\n):\n    \"\"\"Test multi-step tool calling matches HF template.\"\"\"\n    messages, tools = get_multi_step_tool_conversation_for_generation()\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n\n    prefix_messages = kimi_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=\"You are a helpful assistant.\"\n    )\n    prefix_messages = [m for m in prefix_messages if m[\"role\"] == \"tool_declare\"]\n    full_messages = prefix_messages + messages\n\n    cookbook_tokens = kimi_renderer.build_generation_prompt(full_messages).to_ints()\n\n    hf_messages = [kimi_renderer.to_openai_message(m) for m in messages]\n    hf_tokens = get_hf_tokens(\n        kimi_tokenizer,\n        hf_messages,\n        hf_generation_prompt_length,\n        tools=openai_tools,\n        for_generation=True,\n    )\n\n    assert cookbook_tokens == hf_tokens, (\n        f\"Cookbook string: {kimi_tokenizer.decode(cookbook_tokens)}\\n\"\n        f\"HF string: {kimi_tokenizer.decode(hf_tokens)}\"\n    )\n\n\n# =============================================================================\n# Tool Declaration Format Tests\n# =============================================================================\n\n\ndef 
test_kimi_k25_tool_declaration_is_typescript(kimi_renderer):\n    \"\"\"Test that K2.5 uses TypeScript-style tool declarations.\"\"\"\n    tools = [get_tool_spec()]\n    prefix_messages = kimi_renderer.create_conversation_prefix_with_tools(tools)\n\n    assert len(prefix_messages) >= 1\n    assert prefix_messages[0][\"role\"] == \"tool_declare\"\n\n    tool_content = prefix_messages[0][\"content\"]\n    assert isinstance(tool_content, str)\n\n    # Should be TypeScript style, not JSON\n    assert \"# Tools\" in tool_content\n    assert \"namespace functions\" in tool_content\n    assert \"type get_weather\" in tool_content\n    assert '\"type\":\"function\"' not in tool_content\n\n\n@pytest.mark.parametrize(\"build_mode\", [\"generation\", \"supervised\"])\ndef test_kimi_k25_tool_declaration_matches_hf(\n    build_mode: str, kimi_tokenizer, kimi_renderer, hf_generation_prompt_length\n):\n    \"\"\"Test that tool declarations match HF template output.\"\"\"\n    tools = [get_tool_spec()]\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n\n    prefix_messages = kimi_renderer.create_conversation_prefix_with_tools(tools)\n    user_msg = Message(role=\"user\", content=\"What's the weather in NYC?\")\n\n    if build_mode == \"generation\":\n        full_messages = prefix_messages + [user_msg]\n        cookbook_tokens = kimi_renderer.build_generation_prompt(full_messages).to_ints()\n    else:\n        assistant_msg = Message(role=\"assistant\", content=\"Let me check that for you.\")\n        full_messages = prefix_messages + [user_msg, assistant_msg]\n        model_input, _ = kimi_renderer.build_supervised_example(full_messages)\n        cookbook_tokens = model_input.to_ints()\n\n    hf_messages = [\n        {\"role\": \"system\", \"content\": kimi_renderer.DEFAULT_SYSTEM_PROMPT},\n        {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n    ]\n    if build_mode == \"supervised\":\n        hf_messages.append({\"role\": \"assistant\", \"content\": \"Let me check that for you.\"})\n\n    hf_tokens = get_hf_tokens(\n        kimi_tokenizer,\n        hf_messages,\n        hf_generation_prompt_length,\n        tools=openai_tools,\n        for_generation=(build_mode == \"generation\"),\n    )\n\n    assert cookbook_tokens == hf_tokens, (\n        f\"Mode: {build_mode}\\n\"\n        f\"Cookbook string: {kimi_tokenizer.decode(cookbook_tokens)}\\n\"\n        f\"HF string: {kimi_tokenizer.decode(hf_tokens)}\"\n    )\n\n\n# =============================================================================\n# Thinking Content Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\"build_mode\", [\"generation\", \"supervised\"])\ndef test_kimi_k25_thinking_preserved_in_suffix(build_mode: str, kimi_tokenizer, kimi_renderer):\n    \"\"\"Test that thinking is preserved for messages in the suffix (after last non-tool-call assistant).\"\"\"\n    # For supervised, thinking in last assistant should be preserved\n    # For generation with tool calls, thinking in tool-calling assistants should be preserved\n    if build_mode == \"supervised\":\n        messages: list[Message] = [\n            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n            {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n            {\n                \"role\": \"assistant\",\n                \"content\": [\n                    {\"type\": \"thinking\", \"thinking\": \"Let me calculate. 
2+2=4.\"},\n                    {\"type\": \"text\", \"text\": \"The answer is 4.\"},\n                ],\n            },\n        ]\n        model_input, _ = kimi_renderer.build_supervised_example(messages)\n        decoded = kimi_tokenizer.decode(model_input.to_ints())\n    else:\n        # Generation with tool calls - thinking should be preserved\n        messages, tools = get_tool_call_conversation_for_generation()\n        prefix_messages = kimi_renderer.create_conversation_prefix_with_tools(\n            tools, system_prompt=\"You are a helpful assistant.\"\n        )\n        prefix_messages = [m for m in prefix_messages if m[\"role\"] == \"tool_declare\"]\n        full_messages = prefix_messages + messages\n        gen_prompt = kimi_renderer.build_generation_prompt(full_messages)\n        decoded = kimi_tokenizer.decode(gen_prompt.to_ints())\n\n    # Thinking should be preserved\n    if build_mode == \"supervised\":\n        assert \"<think>Let me calculate. 2+2=4.</think>\" in decoded\n    else:\n        assert \"<think>I need to check the weather in New York City.</think>\" in decoded\n\n\n@pytest.mark.parametrize(\"build_mode\", [\"generation\", \"supervised\"])\ndef test_kimi_k25_thinking_stripped_in_history(build_mode: str, kimi_tokenizer, kimi_renderer):\n    \"\"\"Test that thinking is stripped for historical messages (before last non-tool-call assistant).\"\"\"\n    # Conversation with historical assistant message followed by more turns\n    messages: list[Message] = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\"type\": \"thinking\", \"thinking\": \"HISTORICAL_THINKING_SHOULD_BE_STRIPPED\"},\n                {\"type\": \"text\", \"text\": \"The answer is 4.\"},\n            ],\n        },\n        {\"role\": \"user\", \"content\": \"What is 3+3?\"},\n    ]\n\n    if build_mode == \"supervised\":\n        messages.append(\n            {\n                \"role\": \"assistant\",\n                \"content\": [\n                    {\"type\": \"thinking\", \"thinking\": \"SUFFIX_THINKING_PRESERVED\"},\n                    {\"type\": \"text\", \"text\": \"The answer is 6.\"},\n                ],\n            }\n        )\n        model_input, _ = kimi_renderer.build_supervised_example(messages)\n        decoded = kimi_tokenizer.decode(model_input.to_ints())\n    else:\n        gen_prompt = kimi_renderer.build_generation_prompt(messages)\n        decoded = kimi_tokenizer.decode(gen_prompt.to_ints())\n\n    # Historical thinking should be stripped\n    assert \"HISTORICAL_THINKING_SHOULD_BE_STRIPPED\" not in decoded\n    assert \"<think></think>The answer is 4.\" in decoded\n\n    # Suffix thinking should be preserved (only for supervised)\n    if build_mode == \"supervised\":\n        assert \"SUFFIX_THINKING_PRESERVED\" in decoded\n\n\n# =============================================================================\n# EOT Token Tests\n# =============================================================================\n\n\ndef test_kimi_k25_eot_parsing(kimi_tokenizer, kimi_renderer):\n    \"\"\"Test EOT token parsing for K2.5 renderer.\"\"\"\n    # Test with EOT token\n    test_response = \"The answer is 42.<|im_end|>\"\n    response_tokens = kimi_tokenizer.encode(test_response)\n\n    message, format_correct = kimi_renderer.parse_response(response_tokens)\n    assert message[\"role\"] == 
\"assistant\"\n    assert message[\"content\"] == \"The answer is 42.\"\n    assert format_correct is True\n\n    # Test without EOT token\n    test_response_no_eot = \"The answer is 42.\"\n    response_tokens_no_eot = kimi_tokenizer.encode(test_response_no_eot)\n\n    message, format_correct = kimi_renderer.parse_response(response_tokens_no_eot)\n    assert message[\"role\"] == \"assistant\"\n    assert message[\"content\"] == \"The answer is 42.\"\n    assert format_correct is False\n\n\ndef test_kimi_k25_parse_response_restores_prefilled_think_tag(kimi_tokenizer, kimi_renderer):\n    response_tokens = kimi_tokenizer.encode(\n        \"reasoning...</think>2<|im_end|>\",\n        add_special_tokens=False,\n    )\n\n    parsed_message, parse_success = kimi_renderer.parse_response(response_tokens)\n\n    assert parse_success is True\n    assert parsed_message[\"content\"] == [\n        ThinkingPart(type=\"thinking\", thinking=\"reasoning...\"),\n        TextPart(type=\"text\", text=\"2\"),\n    ]\n\n\ndef test_kimi_k25_parse_response_streaming_restores_prefilled_think_tag(\n    kimi_tokenizer, kimi_renderer\n):\n    response_tokens = kimi_tokenizer.encode(\n        \"reasoning...</think>2<|im_end|>\",\n        add_special_tokens=False,\n    )\n\n    deltas = list(kimi_renderer.parse_response_streaming(response_tokens))\n    thinking_text = \"\".join(\n        delta.thinking for delta in deltas if isinstance(delta, StreamingThinkingDelta)\n    )\n    output_text = \"\".join(delta.text for delta in deltas if isinstance(delta, StreamingTextDelta))\n    final_message = cast(Message, deltas[-1])\n\n    assert thinking_text == \"reasoning...\"\n    assert output_text == \"2\"\n    assert final_message[\"content\"] == [\n        ThinkingPart(type=\"thinking\", thinking=\"reasoning...\"),\n        TextPart(type=\"text\", text=\"2\"),\n    ]\n\n\n# =============================================================================\n# Image Content Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\n    \"image_dimensions_and_expected_tokens\", [(2048, 1365, 3626), (17, 64, 3), (5000, 6000, 4189)]\n)\ndef test_kimi_k25_image_content(image_dimensions_and_expected_tokens: tuple[int, int, int]):\n    \"\"\"Test that image-content is encoded properly for kimi2.5\"\"\"\n    width, height, expected_tokens = image_dimensions_and_expected_tokens\n    dummy_image = Image.new(\"RGB\", (width, height))\n    messages = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"image\", \"image\": dummy_image},\n                {\"type\": \"text\", \"text\": \"Can you describe this image?\"},\n            ],\n        },\n        {\"role\": \"assistant\", \"content\": \"That looks like a blank image?\"},\n    ]\n\n    tokenizer = get_tokenizer(KIMI_K25_MODEL)\n    image_processor = get_image_processor(KIMI_K25_MODEL)\n\n    hf_output = extract_token_ids(tokenizer.apply_chat_template(messages, tokenize=True))\n\n    renderer = get_renderer(\"kimi_k25\", tokenizer, image_processor)\n    renderer_output = renderer.build_generation_prompt(messages)\n\n    # Compare HF and renderer tokens\n    hf_offset = 0\n    for chunk in renderer_output.chunks:\n        if isinstance(chunk, tinker.EncodedTextChunk):\n            assert list(chunk.tokens) == hf_output[hf_offset : hf_offset + len(chunk.tokens)]\n            hf_offset += 
len(chunk.tokens)\n        elif isinstance(chunk, tinker.types.image_chunk.ImageChunk):\n            assert hf_output[hf_offset : hf_offset + 1] == tokenizer.encode(\"<|media_pad|>\")\n            assert chunk.expected_tokens == expected_tokens, (\n                f\"Expected {expected_tokens} tokens for image, got {chunk.expected_tokens}\"\n            )\n            hf_offset += 1\n        else:\n            raise ValueError(f\"Unknown chunk type: {type(chunk)}\")\n    assert hf_offset == len(hf_output)\n"
  },
  {
    "path": "tinker_cookbook/renderers/kimi_k2_5_tool_declaration_ts.py",
    "content": "\"\"\"\nEncode structured tool declaration to typescript style string.\n\nCopied from kimi-k2.5-hf-tokenizer/tool_declaration_ts.py for Kimi K2.5 support.\n\"\"\"\n\nimport dataclasses\nimport json\nimport logging\nfrom collections.abc import Sequence\nfrom typing import Any\n\nlogger = logging.getLogger(__name__)\n\n_TS_INDENT = \"  \"\n_TS_FIELD_DELIMITER = \",\\n\"\n\n\nclass _SchemaRegistry:\n    \"\"\"Registry for schema definitions to handle $ref resolution\"\"\"\n\n    def __init__(self):\n        self.definitions = {}\n        self.has_self_ref = False\n\n    def register_definitions(self, defs: dict[str, Any]):\n        \"\"\"Register schema definitions from $defs section\"\"\"\n        if not defs:\n            return\n        for def_name, def_schema in defs.items():\n            self.definitions[def_name] = def_schema\n\n    def resolve_ref(self, ref: str) -> dict[str, Any]:\n        \"\"\"Resolve a reference to its schema definition\"\"\"\n        if ref == \"#\":\n            self.has_self_ref = True\n            return {\"$self_ref\": True}\n        elif ref.startswith(\"#/$defs/\"):\n            def_name = ref.split(\"/\")[-1]\n            if def_name not in self.definitions:\n                raise ValueError(f\"Reference not found: {ref}\")\n            return self.definitions[def_name]\n        else:\n            raise ValueError(f\"Unsupported reference format: {ref}\")\n\n\ndef _format_description(description: str, indent: str = \"\") -> str:\n    return \"\\n\".join([f\"{indent}// {line}\" if line else \"\" for line in description.split(\"\\n\")])\n\n\nclass _BaseType:\n    description: str\n    constraints: dict[str, Any]\n\n    def __init__(\n        self,\n        extra_props: dict[str, Any],\n        *,\n        allowed_constraint_keys: Sequence[str] = (),\n    ):\n        self.description = extra_props.get(\"description\", \"\")\n        self.constraints = {k: v for k, v in extra_props.items() if k in allowed_constraint_keys}\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        raise NotImplementedError\n\n    def format_docstring(self, indent: str) -> str:\n        lines = []\n        if self.description:\n            lines.append(_format_description(self.description, indent))\n        if self.constraints:\n            constraints_str = \", \".join(\n                f\"{k}: {v}\" for k, v in sorted(self.constraints.items(), key=lambda kv: kv[0])\n            )\n            lines.append(f\"{indent}// {constraints_str}\")\n\n        return \"\".join(x + \"\\n\" for x in lines)\n\n\nclass _ParameterTypeScalar(_BaseType):\n    type: str\n\n    def __init__(self, type: str, extra_props: dict[str, Any] | None = None):\n        self.type = type\n\n        allowed_constraint_keys: list[str] = []\n        if self.type == \"string\":\n            allowed_constraint_keys = [\"maxLength\", \"minLength\", \"pattern\"]\n        elif self.type in (\"number\", \"integer\"):\n            allowed_constraint_keys = [\"maximum\", \"minimum\"]\n\n        super().__init__(extra_props or {}, allowed_constraint_keys=allowed_constraint_keys)\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        # Map integer to number in TypeScript\n        if self.type == \"integer\":\n            return \"number\"\n        return self.type\n\n\nclass _ParameterTypeObject(_BaseType):\n    properties: list[\"_Parameter\"]\n    additional_properties: Any | None = None\n\n    def __init__(self, json_schema_object: dict[str, Any], registry: 
_SchemaRegistry | None = None):\n        super().__init__(json_schema_object)\n\n        self.properties = []\n        self.additional_properties = None\n\n        if not json_schema_object:\n            return\n\n        if \"$defs\" in json_schema_object and registry:\n            registry.register_definitions(json_schema_object[\"$defs\"])\n\n        self.additional_properties = json_schema_object.get(\"additionalProperties\")\n        if isinstance(self.additional_properties, dict):\n            self.additional_properties = _parse_parameter_type(self.additional_properties, registry)\n\n        if \"properties\" not in json_schema_object:\n            return\n\n        required_parameters = json_schema_object.get(\"required\", [])\n        optional_parameters = set(json_schema_object[\"properties\"].keys()) - set(\n            required_parameters\n        )\n\n        self.properties = [\n            _Parameter(\n                name=name,\n                type=_parse_parameter_type(prop, registry),\n                optional=name in optional_parameters,\n                default=prop.get(\"default\") if isinstance(prop, dict) else None,\n            )\n            for name, prop in json_schema_object[\"properties\"].items()\n        ]\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        # sort by optional, make the required parameters first\n        parameters = [p for p in self.properties if not p.optional]\n        opt_params = [p for p in self.properties if p.optional]\n\n        parameters = sorted(parameters, key=lambda p: p.name)\n        parameters.extend(sorted(opt_params, key=lambda p: p.name))\n\n        param_strs = []\n        for p in parameters:\n            one = p.to_typescript_style(indent=indent + _TS_INDENT)\n            param_strs.append(one)\n\n        if self.additional_properties is not None:\n            ap_type_str = \"any\"\n            if self.additional_properties is True:\n                ap_type_str = \"any\"\n            elif self.additional_properties is False:\n                ap_type_str = \"never\"\n            elif isinstance(self.additional_properties, _ParameterType):\n                ap_type_str = self.additional_properties.to_typescript_style(\n                    indent=indent + _TS_INDENT\n                )\n            else:\n                raise ValueError(f\"Unknown additionalProperties: {self.additional_properties}\")\n            param_strs.append(f\"{indent + _TS_INDENT}[k: string]: {ap_type_str}\")\n\n        if not param_strs:\n            return \"{}\"\n\n        params_str = _TS_FIELD_DELIMITER.join(param_strs)\n        if params_str:\n            # add new line before and after\n            params_str = f\"\\n{params_str}\\n\"\n        # always wrap with object\n        return f\"{{{params_str}{indent}}}\"\n\n\nclass _ParameterTypeArray(_BaseType):\n    item: \"_ParameterType\"\n\n    def __init__(self, json_schema_object: dict[str, Any], registry: _SchemaRegistry | None = None):\n        super().__init__(json_schema_object, allowed_constraint_keys=(\"minItems\", \"maxItems\"))\n        if json_schema_object.get(\"items\"):\n            self.item = _parse_parameter_type(json_schema_object[\"items\"], registry)\n        else:\n            self.item = _ParameterTypeScalar(type=\"any\")\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        item_docstring = self.item.format_docstring(indent + _TS_INDENT)\n        if item_docstring:\n            return (\n                \"Array<\\n\"\n              
  + item_docstring\n                + indent\n                + _TS_INDENT\n                + self.item.to_typescript_style(indent=indent + _TS_INDENT)\n                + \"\\n\"\n                + indent\n                + \">\"\n            )\n        else:\n            return f\"Array<{self.item.to_typescript_style(indent=indent)}>\"\n\n\nclass _ParameterTypeEnum(_BaseType):\n    # support scalar types only\n    enum: list[str | int | float | bool | None]\n\n    def __init__(self, json_schema_object: dict[str, Any]):\n        super().__init__(json_schema_object)\n        self.enum = json_schema_object[\"enum\"]\n\n        # Validate enum values against declared type if present\n        if \"type\" in json_schema_object:\n            typ = json_schema_object[\"type\"]\n            if isinstance(typ, list):\n                if len(typ) == 1:\n                    typ = typ[0]\n                elif len(typ) == 2:\n                    if \"null\" not in typ:\n                        raise ValueError(f\"Enum type {typ} is not supported\")\n                    else:\n                        typ = typ[0] if typ[0] != \"null\" else typ[1]\n                else:\n                    raise ValueError(f\"Enum type {typ} is not supported\")\n            for val in self.enum:\n                if val is None:\n                    continue\n                if typ == \"string\" and not isinstance(val, str):\n                    raise ValueError(f\"Enum value {val} is not a string\")\n                elif typ == \"number\" and not isinstance(val, (int, float)):\n                    raise ValueError(f\"Enum value {val} is not a number\")\n                elif typ == \"integer\" and not isinstance(val, int):\n                    raise ValueError(f\"Enum value {val} is not an integer\")\n                elif typ == \"boolean\" and not isinstance(val, bool):\n                    raise ValueError(f\"Enum value {val} is not a boolean\")\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        return \" | \".join([f'\"{e}\"' if isinstance(e, str) else str(e) for e in self.enum])\n\n\nclass _ParameterTypeAnyOf(_BaseType):\n    types: list[\"_ParameterType\"]\n\n    def __init__(\n        self,\n        json_schema_object: dict[str, Any],\n        registry: _SchemaRegistry | None = None,\n    ):\n        super().__init__(json_schema_object)\n        self.types = [_parse_parameter_type(t, registry) for t in json_schema_object[\"anyOf\"]]\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        return \" | \".join([t.to_typescript_style(indent=indent) for t in self.types])\n\n\nclass _ParameterTypeUnion(_BaseType):\n    types: list[str]\n\n    def __init__(self, json_schema_object: dict[str, Any]):\n        super().__init__(json_schema_object)\n\n        mapping = {\n            \"string\": \"string\",\n            \"number\": \"number\",\n            \"integer\": \"number\",\n            \"boolean\": \"boolean\",\n            \"null\": \"null\",\n            \"object\": \"{}\",\n            \"array\": \"Array<any>\",\n        }\n        self.types = [mapping[t] for t in json_schema_object[\"type\"]]\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        return \" | \".join(self.types)\n\n\nclass _ParameterTypeRef(_BaseType):\n    ref_name: str\n    is_self_ref: bool = False\n\n    def __init__(self, json_schema_object: dict[str, Any], registry: _SchemaRegistry):\n        super().__init__(json_schema_object)\n\n        ref = json_schema_object[\"$ref\"]\n        
resolved_schema = registry.resolve_ref(ref)\n\n        if resolved_schema.get(\"$self_ref\", False):\n            self.ref_name = \"parameters\"\n            self.is_self_ref = True\n        else:\n            self.ref_name = ref.split(\"/\")[-1]\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        return self.ref_name\n\n\n_ParameterType = (\n    _ParameterTypeScalar\n    | _ParameterTypeObject\n    | _ParameterTypeArray\n    | _ParameterTypeEnum\n    | _ParameterTypeAnyOf\n    | _ParameterTypeUnion\n    | _ParameterTypeRef\n)\n\n\n@dataclasses.dataclass\nclass _Parameter:\n    \"\"\"\n    A parameter in a function, or a field in a object.\n    It consists of the type as well as the name.\n    \"\"\"\n\n    type: _ParameterType\n    name: str = \"_\"\n    optional: bool = True\n    default: Any | None = None\n\n    @classmethod\n    def parse_extended(cls, attributes: dict[str, Any]) -> \"_Parameter\":\n        if not attributes:\n            raise ValueError(\"attributes is empty\")\n\n        return cls(\n            name=attributes.get(\"name\", \"_\"),\n            type=_parse_parameter_type(attributes),\n            optional=attributes.get(\"optional\", False),\n            default=attributes.get(\"default\"),\n        )\n\n    def to_typescript_style(self, indent: str = \"\") -> str:\n        comments = self.type.format_docstring(indent)\n\n        if self.default is not None:\n            default_repr = (\n                json.dumps(self.default, ensure_ascii=False)\n                if not isinstance(self.default, (int, float, bool))\n                else repr(self.default)\n            )\n            comments += f\"{indent}// Default: {default_repr}\\n\"\n\n        return (\n            comments\n            + f\"{indent}{self.name}{'?' 
if self.optional else ''}: {self.type.to_typescript_style(indent=indent)}\"\n        )\n\n\ndef _parse_parameter_type(\n    json_schema_object: dict[str, Any] | bool, registry: _SchemaRegistry | None = None\n) -> _ParameterType:\n    if isinstance(json_schema_object, bool):\n        if json_schema_object:\n            return _ParameterTypeScalar(type=\"any\")\n        else:\n            logger.warning(\n                f\"Warning: Boolean value {json_schema_object} is not supported, use null instead.\"\n            )\n            return _ParameterTypeScalar(type=\"null\")\n\n    if \"$ref\" in json_schema_object and registry:\n        return _ParameterTypeRef(json_schema_object, registry)\n\n    if \"anyOf\" in json_schema_object:\n        return _ParameterTypeAnyOf(json_schema_object, registry)\n    elif \"enum\" in json_schema_object:\n        return _ParameterTypeEnum(json_schema_object)\n    elif \"type\" in json_schema_object:\n        typ = json_schema_object[\"type\"]\n        if isinstance(typ, list):\n            return _ParameterTypeUnion(json_schema_object)\n        elif typ == \"object\":\n            return _ParameterTypeObject(json_schema_object, registry)\n        elif typ == \"array\":\n            return _ParameterTypeArray(json_schema_object, registry)\n        else:\n            return _ParameterTypeScalar(typ, json_schema_object)\n    elif json_schema_object == {}:\n        return _ParameterTypeScalar(type=\"any\")\n    else:\n        raise ValueError(f\"Invalid JSON Schema object: {json_schema_object}\")\n\n\ndef _openai_function_to_typescript_style(\n    function: dict[str, Any],\n) -> str:\n    \"\"\"Convert OpenAI function definition (dict) to TypeScript style string.\"\"\"\n    registry = _SchemaRegistry()\n    parameters = function.get(\"parameters\") or {}\n    parsed = _ParameterTypeObject(parameters, registry)\n\n    interfaces = []\n    root_interface_name = None\n    if registry.has_self_ref:\n        root_interface_name = \"parameters\"\n        params_str = _TS_FIELD_DELIMITER.join(\n            [p.to_typescript_style(indent=_TS_INDENT) for p in parsed.properties]\n        )\n        params_str = f\"\\n{params_str}\\n\" if params_str else \"\"\n        interface_def = f\"interface {root_interface_name} {{{params_str}}}\"\n        interfaces.append(interface_def)\n\n    definitions_copy = dict(registry.definitions)\n    for def_name, def_schema in definitions_copy.items():\n        obj_type = _parse_parameter_type(def_schema, registry)\n        params_str = obj_type.to_typescript_style()\n\n        description_part = \"\"\n        if obj_description := def_schema.get(\"description\", \"\"):\n            description_part = _format_description(obj_description) + \"\\n\"\n\n        interface_def = f\"{description_part}interface {def_name} {params_str}\"\n        interfaces.append(interface_def)\n\n    interface_str = \"\\n\".join(interfaces)\n    function_name = function.get(\"name\", \"function\")\n    if root_interface_name:\n        type_def = f\"type {function_name} = (_: {root_interface_name}) => any;\"\n    else:\n        params_str = parsed.to_typescript_style()\n        type_def = f\"type {function_name} = (_: {params_str}) => any;\"\n\n    description = function.get(\"description\")\n    return \"\\n\".join(\n        filter(\n            bool,\n            [\n                interface_str,\n                ((description and _format_description(description)) or \"\"),\n                type_def,\n            ],\n        )\n    )\n\n\ndef 
encode_tools_to_typescript_style(\n    tools: list[dict[str, Any]],\n) -> str:\n    \"\"\"\n    Convert tools (list of dict) to TypeScript style string.\n\n    Supports OpenAI format: {\"type\": \"function\", \"function\": {...}}\n\n    Args:\n        tools: List of tool definitions in dict format\n\n    Returns:\n        TypeScript style string representation of the tools\n    \"\"\"\n    if not tools:\n        return \"\"\n\n    functions = []\n\n    for tool in tools:\n        tool_type = tool.get(\"type\")\n        if tool_type == \"function\":\n            func_def = tool.get(\"function\", {})\n            if func_def:\n                functions.append(_openai_function_to_typescript_style(func_def))\n        else:\n            # Skip unsupported tool types (like \"_plugin\")\n            continue\n\n    if not functions:\n        return \"\"\n\n    functions_str = \"\\n\".join(functions)\n    result = \"# Tools\\n\\n\"\n\n    if functions_str:\n        result += \"## functions\\nnamespace functions {\\n\"\n        result += functions_str + \"\\n\"\n        result += \"}\\n\"\n\n    return result\n"
  },
  {
    "path": "tinker_cookbook/renderers/kimi_k2_test.py",
    "content": "\"\"\"Tests specific to KimiK2Renderer (streaming parsing, thinking stripping, supervised examples).\"\"\"\n\nfrom typing import TypeGuard\n\nimport pytest\n\nfrom tinker_cookbook.renderers import (\n    Message,\n    StreamingMessageHeader,\n    StreamingTextDelta,\n    StreamingThinkingDelta,\n    TextPart,\n    ThinkingPart,\n    ToolCall,\n    TrainOnWhat,\n    get_renderer,\n)\nfrom tinker_cookbook.renderers.base import ensure_list\nfrom tinker_cookbook.renderers.kimi_k2 import KimiK2Renderer\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\ndef _is_message(obj) -> TypeGuard[Message]:\n    \"\"\"Check if object is a Message dict (TypedDict doesn't support isinstance).\"\"\"\n    return isinstance(obj, dict) and \"role\" in obj and \"content\" in obj\n\n\n# =============================================================================\n# Conversation helpers\n# =============================================================================\n\n\ndef _get_basic_4turn() -> list[Message]:\n    return [\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\"role\": \"assistant\", \"content\": \"The answer is 4.\"},\n        {\"role\": \"user\", \"content\": \"And what is 3+3?\"},\n        {\"role\": \"assistant\", \"content\": \"The answer is 6.\"},\n    ]\n\n\ndef _get_tool_call_conversation() -> list[Message]:\n    return [\n        {\"role\": \"user\", \"content\": \"What's the weather in San Francisco?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": \"I'll check the weather for you.\",\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"San Francisco\"}',\n                    ),\n                    id=\"call_123\",\n                )\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"content\": '{\"temperature\": 72, \"condition\": \"sunny\"}',\n            \"tool_call_id\": \"call_123\",\n            \"name\": \"get_weather\",\n        },\n        {\"role\": \"assistant\", \"content\": \"The weather in San Francisco is sunny with 72°F.\"},\n    ]\n\n\n# =============================================================================\n# KimiK2 Streaming Parsing Tests\n# =============================================================================\n\n\ndef test_kimi_streaming_simple_text():\n    \"\"\"Test streaming parsing of simple text response.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = KimiK2Renderer(tokenizer)\n\n    response_str = \"Hello, world!<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert deltas[0].role == \"assistant\"\n\n    assert _is_message(deltas[-1])\n    assert deltas[-1][\"role\"] == \"assistant\"\n\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n    assert \"Hello, world!\" in text_content\n\n\ndef test_kimi_streaming_with_thinking():\n    \"\"\"Test streaming parsing with thinking blocks.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = KimiK2Renderer(tokenizer)\n\n    response_str = \"<think>Let me reason about this.</think>The answer is 42.<|im_end|>\"\n    response_tokens = 
tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert deltas[0].role == \"assistant\"\n\n    thinking_content = \"\".join(d.thinking for d in deltas if isinstance(d, StreamingThinkingDelta))\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n\n    assert \"Let me reason about this.\" in thinking_content\n    assert \"The answer is 42.\" in text_content\n\n    final_message = deltas[-1]\n    assert _is_message(final_message)\n\n\ndef test_kimi_streaming_matches_batch():\n    \"\"\"Test that streaming parse produces same final message as batch parse.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = KimiK2Renderer(tokenizer)\n\n    response_str = \"<think>Step 1: Analyze.\\nStep 2: Compute.</think>The result is 123.<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    batch_message, batch_success = renderer.parse_response(response_tokens)\n    assert batch_success\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n    streaming_message = deltas[-1]\n\n    assert _is_message(streaming_message)\n    assert streaming_message[\"role\"] == batch_message[\"role\"]\n    assert ensure_list(streaming_message[\"content\"]) == ensure_list(batch_message[\"content\"])\n\n\ndef test_kimi_streaming_content_index_increments():\n    \"\"\"Test that content_index increments when switching content types.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = KimiK2Renderer(tokenizer)\n\n    response_str = \"<think>thinking</think>text<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    thinking_indices = [d.content_index for d in deltas if isinstance(d, StreamingThinkingDelta)]\n    text_indices = [d.content_index for d in deltas if isinstance(d, StreamingTextDelta)]\n\n    if thinking_indices and text_indices:\n        assert max(text_indices) > min(thinking_indices)\n\n\ndef test_kimi_streaming_multiple_think_blocks():\n    \"\"\"Test streaming with multiple interleaved think blocks.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = KimiK2Renderer(tokenizer)\n\n    response_str = \"<think>first thought</think>partial<think>second thought</think>final<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    batch_message, _ = renderer.parse_response(response_tokens)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    thinking_content = \"\".join(d.thinking for d in deltas if isinstance(d, StreamingThinkingDelta))\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n\n    assert \"first thought\" in thinking_content\n    assert \"second thought\" in thinking_content\n    assert \"partial\" in text_content\n    assert \"final\" in text_content\n\n    streaming_message = deltas[-1]\n    assert _is_message(streaming_message)\n    assert ensure_list(streaming_message[\"content\"]) == ensure_list(batch_message[\"content\"])\n\n\ndef test_kimi_streaming_empty_response():\n    \"\"\"Test streaming parsing of empty/minimal response.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = 
KimiK2Renderer(tokenizer)\n\n    response_str = \"<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert _is_message(deltas[-1])\n\n\ndef test_kimi_streaming_no_unnecessary_buffering():\n    \"\"\"Test that we don't buffer more than necessary when no tag prefix matches.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = KimiK2Renderer(tokenizer)\n\n    response_str = \"Hello world<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n    assert text_content == \"Hello world\"\n\n\ndef test_kimi_streaming_with_emoji():\n    \"\"\"Test that streaming parser handles emoji correctly.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = KimiK2Renderer(tokenizer)\n\n    response_str = \"<think>Let me think 🤔</think>Here's a party 🎉!<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    thinking_content = \"\".join(d.thinking for d in deltas if isinstance(d, StreamingThinkingDelta))\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n\n    assert \"�\" not in thinking_content, f\"Thinking has replacement chars: {thinking_content!r}\"\n    assert \"�\" not in text_content, f\"Text has replacement chars: {text_content!r}\"\n\n    assert \"🤔\" in thinking_content, f\"Missing thinking emoji in: {thinking_content!r}\"\n    assert \"🎉\" in text_content, f\"Missing party emoji in: {text_content!r}\"\n\n    final_messages = [d for d in deltas if isinstance(d, dict) and \"role\" in d]\n    assert len(final_messages) == 1\n    final = final_messages[0]\n\n    content = final[\"content\"]\n    if isinstance(content, list):\n        final_thinking = \"\".join(p[\"thinking\"] for p in content if p[\"type\"] == \"thinking\")\n        final_text = \"\".join(p[\"text\"] for p in content if p[\"type\"] == \"text\")\n    else:\n        final_thinking = \"\"\n        final_text = content\n\n    assert \"🤔\" in final_thinking, \"Final message missing thinking emoji\"\n    assert \"🎉\" in final_text, \"Final message missing party emoji\"\n\n\n# =============================================================================\n# Streaming vs Batch Equivalence Tests\n# =============================================================================\n\n\ndef _assert_streaming_matches_batch(renderer, response_str: str):\n    \"\"\"Helper: verify streaming and batch parsing produce identical results.\"\"\"\n    tokenizer = renderer.tokenizer\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    batch_message, batch_success = renderer.parse_response(response_tokens)\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert len(deltas) >= 2, \"Should have at least header + final message\"\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert _is_message(deltas[-1])\n\n    streaming_message = deltas[-1]\n    assert streaming_message[\"role\"] == batch_message[\"role\"]\n    assert ensure_list(streaming_message[\"content\"]) == 
ensure_list(batch_message[\"content\"])\n    assert streaming_message.get(\"tool_calls\") == batch_message.get(\"tool_calls\")\n    assert streaming_message.get(\"unparsed_tool_calls\") == batch_message.get(\"unparsed_tool_calls\")\n\n    # Verify streamed deltas reconstruct the content\n    thinking_from_deltas = \"\".join(\n        d.thinking for d in deltas if isinstance(d, StreamingThinkingDelta)\n    )\n    text_from_deltas = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n\n    batch_content = batch_message[\"content\"]\n    if isinstance(batch_content, list):\n        expected_thinking = \"\".join(p[\"thinking\"] for p in batch_content if p[\"type\"] == \"thinking\")\n        expected_text = \"\".join(p[\"text\"] for p in batch_content if p[\"type\"] == \"text\")\n    else:\n        expected_thinking = \"\"\n        expected_text = batch_content\n\n    assert thinking_from_deltas == expected_thinking\n    # Text deltas may include tool call markup before final parsing strips it\n    if not batch_message.get(\"tool_calls\") and not batch_message.get(\"unparsed_tool_calls\"):\n        assert text_from_deltas == expected_text\n\n    return deltas, batch_message\n\n\nclass TestKimiK2StreamingBatchEquivalence:\n    \"\"\"Verify parse_response_streaming matches parse_response for all patterns.\"\"\"\n\n    @pytest.fixture\n    def renderer(self):\n        tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n        return KimiK2Renderer(tokenizer)\n\n    def test_simple_text(self, renderer):\n        _assert_streaming_matches_batch(renderer, \"Hello, world!<|im_end|>\")\n\n    def test_thinking_then_text(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>Let me reason step by step.\\n1. First...\\n2. Then...</think>\"\n            \"The answer is 42.<|im_end|>\",\n        )\n\n    def test_empty_thinking(self, renderer):\n        _assert_streaming_matches_batch(renderer, \"<think></think>Direct answer.<|im_end|>\")\n\n    def test_long_thinking(self, renderer):\n        thinking = (\n            \"First, let me understand the problem.\\n\\n\"\n            \"Key concepts:\\n1. Superposition\\n2. Measurement\\n3. 
Non-locality\\n\\n\"\n            \"I should explain this clearly.\"\n        )\n        _assert_streaming_matches_batch(\n            renderer, f\"<think>{thinking}</think>Quantum entanglement links particles.<|im_end|>\"\n        )\n\n    def test_multiple_think_blocks(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>first thought</think>partial<think>second thought</think>final<|im_end|>\",\n        )\n\n    def test_empty_response(self, renderer):\n        _assert_streaming_matches_batch(renderer, \"<|im_end|>\")\n\n    def test_whitespace_only(self, renderer):\n        _assert_streaming_matches_batch(renderer, \"   \\n\\t  <|im_end|>\")\n\n    def test_special_characters(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>x² + y² = r²</think>Special chars: <>&\\\"'`~!@#$%^&*()<|im_end|>\",\n        )\n\n    def test_emoji(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer, \"<think>🤔 thinking 💭</think>Answer 🎉✨!<|im_end|>\"\n        )\n\n    def test_code_blocks(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>Need a function.</think>\"\n            \"```python\\ndef hello():\\n    print('world')\\n```<|im_end|>\",\n        )\n\n    def test_html_like_content(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>HTML example</think><div><p>Hello</p></div><|im_end|>\",\n        )\n\n    def test_tool_call_with_thinking(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>I need to search.</think>\"\n            \"<|tool_calls_section_begin|>\"\n            \"<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>\"\n            '{\"query\": \"quantum physics\"}'\n            \"<|tool_call_end|>\"\n            \"<|tool_calls_section_end|><|im_end|>\",\n        )\n\n    def test_tool_call_without_thinking(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<|tool_calls_section_begin|>\"\n            \"<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>\"\n            '{\"city\": \"San Francisco\"}'\n            \"<|tool_call_end|>\"\n            \"<|tool_calls_section_end|><|im_end|>\",\n        )\n\n    def test_text_then_tool_call(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"Let me look that up.\"\n            \"<|tool_calls_section_begin|>\"\n            \"<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>\"\n            '{\"query\": \"weather\"}'\n            \"<|tool_call_end|>\"\n            \"<|tool_calls_section_end|><|im_end|>\",\n        )\n\n    def test_multiple_tool_calls(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>Two calls needed.</think>\"\n            \"<|tool_calls_section_begin|>\"\n            \"<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>\"\n            '{\"query\": \"python\"}'\n            \"<|tool_call_end|>\"\n            \"<|tool_call_begin|>functions.calculate:1<|tool_call_argument_begin|>\"\n            '{\"expression\": \"2+2\"}'\n            \"<|tool_call_end|>\"\n            \"<|tool_calls_section_end|><|im_end|>\",\n        )\n\n    def test_multiline_thinking(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n          
  \"<think>\\nStep 1\\n\\nStep 2\\n\\nStep 3\\n</think>\\nResult.\\n<|im_end|>\",\n        )\n\n    def test_no_end_token(self, renderer):\n        \"\"\"Truncated response — streaming should still parse think blocks.\"\"\"\n        tokenizer = renderer.tokenizer\n        response_tokens = tokenizer.encode(\n            \"<think>reasoning</think>partial\", add_special_tokens=False\n        )\n\n        deltas = list(renderer.parse_response_streaming(response_tokens))\n        final = deltas[-1]\n        assert _is_message(final)\n        # Even without end token, streaming should parse think blocks\n        content = final[\"content\"]\n        assert isinstance(content, list), \"Truncated response should still parse think blocks\"\n        thinking = [p for p in content if p[\"type\"] == \"thinking\"]\n        text = [p for p in content if p[\"type\"] == \"text\"]\n        assert len(thinking) == 1 and thinking[0][\"thinking\"] == \"reasoning\"\n        assert len(text) == 1 and text[0][\"text\"] == \"partial\"\n\n    def test_content_index_ordering(self, renderer):\n        \"\"\"Content index strictly increases across type transitions.\"\"\"\n        response_tokens = renderer.tokenizer.encode(\n            \"<think>t1</think>x1<think>t2</think>x2<|im_end|>\", add_special_tokens=False\n        )\n        deltas = list(renderer.parse_response_streaming(response_tokens))\n\n        indexed = []\n        for d in deltas:\n            if isinstance(d, StreamingThinkingDelta):\n                indexed.append((\"thinking\", d.content_index))\n            elif isinstance(d, StreamingTextDelta):\n                indexed.append((\"text\", d.content_index))\n\n        indices = [idx for _, idx in indexed]\n        assert indices == sorted(indices), f\"Not monotonic: {indexed}\"\n        for i in range(1, len(indexed)):\n            if indexed[i][0] != indexed[i - 1][0]:\n                assert indexed[i][1] > indexed[i - 1][1]\n\n\n# =============================================================================\n# KimiK2 Thinking Stripping / Preservation Tests\n# =============================================================================\n\n\ndef test_kimi_k2_thinking_stripped_when_no_suffix_messages():\n    \"\"\"\n    Kimi K2 should preserve thinking only after the last non-tool-call assistant.\n    This test checks that the history thinking is stripped with the presence of a non-tool-call assistant.\n    \"\"\"\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    messages: list[Message] = [\n        {\"role\": \"user\", \"content\": \"Q\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think A\"),\n            ],\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\":\"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\"role\": \"tool\", \"content\": '{\"temperature\": 72}', \"tool_call_id\": \"call_1\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think B\"),\n                TextPart(type=\"text\", text=\"A\"),\n            ],\n        },\n    ]\n\n    model_input, _ = 
renderer.build_supervised_example(messages)\n    decoded = tokenizer.decode(model_input.to_ints())\n\n    assert \"think A\" in decoded, f\"Non-suffix thinking should be preserved: {decoded}\"\n    assert \"think B\" in decoded, f\"Non-suffix thinking should be preserved: {decoded}\"\n    assert \"A\" in decoded, f\"Non-suffix text should be preserved: {decoded}\"\n\n    model_input = renderer.build_generation_prompt(messages)\n    decoded = tokenizer.decode(model_input.to_ints())\n\n    assert \"think A\" not in decoded, f\"History thinking should be stripped: {decoded}\"\n    assert \"think B\" not in decoded, f\"History thinking should be stripped: {decoded}\"\n    assert \"A\" in decoded, f\"History text should be preserved: {decoded}\"\n\n\ndef test_kimi_k2_thinking_preserved_in_suffix_after_last_non_tool_call():\n    \"\"\"\n    Kimi K2 should preserve thinking only after the last non-tool-call assistant.\n    Suffix thinking is preserved but history thinking is stripped relative to the\n    position of the last non-tool-call assistant.\n    \"\"\"\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    messages: list[Message] = [\n        {\"role\": \"user\", \"content\": \"Q1\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think A\"),\n                TextPart(type=\"text\", text=\"A1\"),\n            ],\n            \"tool_calls\": [],\n        },\n        {\"role\": \"user\", \"content\": \"Q2\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think B\"),\n            ],\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\":\"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\"role\": \"tool\", \"content\": '{\"temperature\": 72}', \"tool_call_id\": \"call_1\"},\n    ]\n\n    model_input, _ = renderer.build_supervised_example(messages)\n    decoded = tokenizer.decode(model_input.to_ints())\n\n    assert \"think A\" not in decoded, f\"History thinking should be stripped: {decoded}\"\n    assert \"A1\" in decoded, f\"History text should be preserved: {decoded}\"\n    assert \"think B\" in decoded, f\"Suffix thinking should be preserved: {decoded}\"\n\n    model_input = renderer.build_generation_prompt(messages)\n    decoded = tokenizer.decode(model_input.to_ints())\n\n    assert \"think A\" not in decoded, f\"History thinking should be stripped: {decoded}\"\n    assert \"A1\" in decoded, f\"History text should be preserved: {decoded}\"\n    assert \"think B\" in decoded, f\"Suffix thinking should be preserved: {decoded}\"\n\n\ndef test_kimi_k2_thinking_preserved_when_no_non_tool_call_assistant():\n    \"\"\"\n    When no non-tool-call assistant exists, all thinking should be preserved.\n    \"\"\"\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    messages: list[Message] = [\n        {\"role\": \"user\", \"content\": \"Q\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think A\"),\n          
  ],\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\":\"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\"role\": \"tool\", \"content\": '{\"temperature\": 72}', \"tool_call_id\": \"call_1\"},\n    ]\n\n    model_input, _ = renderer.build_supervised_example(messages)\n    decoded = tokenizer.decode(model_input.to_ints())\n\n    assert \"think A\" in decoded, f\"Suffix thinking should be preserved: {decoded}\"\n\n\n# =============================================================================\n# KimiK2 build_supervised_examples Tests\n# =============================================================================\n\n\ndef test_kimi_k2_build_supervised_examples_last_assistant_matches():\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer: KimiK2Renderer = get_renderer(\"kimi_k2\", tokenizer)  # type: ignore\n\n    messages = _get_basic_4turn()\n\n    single_input, single_weights = renderer.build_supervised_example(messages)\n    examples = renderer.build_supervised_examples(\n        messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_MESSAGE\n    )\n\n    assert len(examples) == 1, \"Expected a single supervised example\"\n    list_input, list_weights = examples[0]\n    assert list_input.to_ints() == single_input.to_ints()\n    assert list_weights.tolist() == single_weights.tolist()\n\n\ndef test_kimi_k2_build_supervised_examples_all_assistant_matches():\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer: KimiK2Renderer = get_renderer(\"kimi_k2\", tokenizer)  # type: ignore\n\n    messages: list[Message] = [\n        {\"role\": \"user\", \"content\": \"Q1\"},\n        {\"role\": \"assistant\", \"content\": \"A1\"},\n        {\"role\": \"user\", \"content\": \"Q2\"},\n        {\"role\": \"assistant\", \"content\": \"A2\"},\n        {\"role\": \"user\", \"content\": \"Q3\"},\n        {\"role\": \"assistant\", \"content\": \"A3\"},\n    ]\n\n    examples = renderer.build_supervised_examples(\n        messages, train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES\n    )\n\n    assert len(examples) == 3, (\n        \"Expected one example per user turn after the first and one for the full conversation\"\n    )\n\n    ex0_tokens = examples[0][0].to_ints()\n    ex1_tokens = examples[1][0].to_ints()\n    ex2_tokens = examples[2][0].to_ints()\n    ex0_decoded = tokenizer.decode(ex0_tokens)\n    ex1_decoded = tokenizer.decode(ex1_tokens)\n    ex2_decoded = tokenizer.decode(ex2_tokens)\n\n    assert \"A1\" in ex0_decoded\n    assert \"A2\" not in ex0_decoded\n    assert \"A3\" not in ex0_decoded\n\n    assert \"A1\" in ex1_decoded\n    assert \"A2\" in ex1_decoded\n    assert \"A3\" not in ex1_decoded\n\n    assert \"A1\" in ex2_decoded\n    assert \"A2\" in ex2_decoded\n    assert \"A3\" in ex2_decoded\n\n\ndef test_kimi_k2_build_supervised_examples_warns_on_non_assistant_mode():\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer: KimiK2Renderer = get_renderer(\"kimi_k2\", tokenizer)  # type: ignore\n\n    messages = _get_basic_4turn()\n\n    with pytest.warns(UserWarning, match=\"does not satisfy the extension property\"):\n        examples = renderer.build_supervised_examples(\n            messages, 
train_on_what=TrainOnWhat.ALL_MESSAGES\n        )\n\n    assert len(examples) == 2, (\n        \"Expected one example for the full conversation and one for the last user turn\"\n    )\n    ex0_tokens = examples[0][0].to_ints()\n    ex1_tokens = examples[1][0].to_ints()\n    ex0_decoded = tokenizer.decode(ex0_tokens)\n    ex1_decoded = tokenizer.decode(ex1_tokens)\n\n    assert \"2+2\" in ex0_decoded\n    assert \"4\" in ex0_decoded\n    assert \"3+3\" not in ex0_decoded\n    assert \"6\" not in ex0_decoded\n\n    assert \"2+2\" in ex1_decoded\n    assert \"4\" in ex1_decoded\n    assert \"3+3\" in ex1_decoded\n    assert \"6\" in ex1_decoded\n\n\ndef test_kimi_k2_build_supervised_examples_all_assistant_matches_with_tool_calls():\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer: KimiK2Renderer = get_renderer(\"kimi_k2\", tokenizer)  # type: ignore\n\n    messages: list[Message] = [\n        {\"role\": \"user\", \"content\": \"Q\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think A\"),\n            ],\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\":\"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\"role\": \"tool\", \"content\": '{\"temperature\": 72}', \"tool_call_id\": \"call_1\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think B\"),\n                TextPart(type=\"text\", text=\"A\"),\n            ],\n        },\n        {\"role\": \"user\", \"content\": \"Q2\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think C\"),\n            ],\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\":\"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\"role\": \"tool\", \"content\": '{\"temperature\": 72}', \"tool_call_id\": \"call_1\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"think D\"),\n                TextPart(type=\"text\", text=\"A2\"),\n            ],\n        },\n    ]\n\n    examples = renderer.build_supervised_examples(\n        messages, train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES\n    )\n\n    assert len(examples) == 2\n    example0_input, example0_weights = examples[0]\n    example1_input, example1_weights = examples[1]\n\n    expected_input, expected_weights = renderer.build_supervised_example(\n        messages[:4], train_on_what=TrainOnWhat.LAST_ASSISTANT_TURN\n    )\n    all_assist_input, all_assist_weights = renderer.build_supervised_example(\n        messages[:4], train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES\n    )\n\n    assert example0_input.to_ints() == expected_input.to_ints()\n    assert example0_weights.tolist() == expected_weights.tolist()\n    # since we only have one turn in `messages[:4]`, the weights should be the same\n    assert example0_weights.tolist() == 
all_assist_weights.tolist()\n\n    expected_input, expected_weights = renderer.build_supervised_example(\n        messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_TURN\n    )\n    all_assist_input, all_assist_weights = renderer.build_supervised_example(\n        messages, train_on_what=TrainOnWhat.ALL_ASSISTANT_MESSAGES\n    )\n\n    assert example1_input.to_ints() == expected_input.to_ints()\n    assert example1_weights.tolist() == expected_weights.tolist()\n    assert example1_weights.tolist() != all_assist_weights.tolist()\n"
  },
  {
    "path": "tinker_cookbook/renderers/kimi_k2_tool_declaration_test.py",
    "content": "\"\"\"Tests for Kimi K2 tool declaration rendering.\"\"\"\n\nimport json\n\nimport pytest\nfrom transformers import AutoTokenizer\n\nfrom tinker_cookbook.renderers import get_renderer\nfrom tinker_cookbook.renderers.base import Message, ToolSpec, ensure_text\nfrom tinker_cookbook.renderers.testing_utils import extract_token_ids\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\n@pytest.mark.parametrize(\n    \"tools,expected_order\",\n    [\n        # Single tool\n        (\n            [\n                {\n                    \"name\": \"get_weather\",\n                    \"description\": \"Get weather for a location\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"City name\"}},\n                        \"required\": [\"location\"],\n                    },\n                }\n            ],\n            [\"tool_declare\", \"system\"],\n        ),\n        # Multiple tools\n        (\n            [\n                {\n                    \"name\": \"tool_a\",\n                    \"description\": \"Tool A\",\n                    \"parameters\": {\"type\": \"object\", \"properties\": {}},\n                },\n                {\n                    \"name\": \"tool_b\",\n                    \"description\": \"Tool B\",\n                    \"parameters\": {\"type\": \"object\", \"properties\": {}},\n                },\n            ],\n            [\"tool_declare\", \"system\"],\n        ),\n    ],\n)\ndef test_tool_declaration_message_order(tools, expected_order):\n    \"\"\"Test that tool_declare message comes before system message.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    messages = renderer.create_conversation_prefix_with_tools(tools, \"\")\n\n    actual_order = [msg[\"role\"] for msg in messages]\n    assert actual_order == expected_order, (\n        f\"Expected message order {expected_order}, got {actual_order}. \"\n        f\"Tool declaration should come BEFORE system message per HF chat template.\"\n    )\n\n\ndef test_tool_declaration_no_duplicate_system():\n    \"\"\"Test that tool declaration doesn't result in duplicate system messages.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    tools: list[ToolSpec] = [\n        {\"name\": \"test\", \"description\": \"Test\", \"parameters\": {\"type\": \"object\", \"properties\": {}}}\n    ]\n    prefix = renderer.create_conversation_prefix_with_tools(tools, \"\")\n    messages = prefix + [{\"role\": \"user\", \"content\": \"Test\"}]\n\n    # Build generation prompt (this triggers _ensure_system_message)\n    prompt = renderer.build_generation_prompt(messages)\n    prompt_str = tokenizer.decode(prompt.to_ints())\n\n    # Count occurrences of system messages\n    system_count = prompt_str.count(\"<|im_system|>system<|im_middle|>\")\n\n    assert system_count == 1, (\n        f\"Expected exactly 1 system message, found {system_count}. 
Prompt:\\n{prompt_str[:500]}\"\n    )\n\n\ndef test_tool_json_keys_are_sorted():\n    \"\"\"Test that tool declaration JSON has sorted keys at all nesting levels.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    tools: list[ToolSpec] = [\n        {\n            \"name\": \"get_weather\",\n            \"description\": \"Get weather\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"unit\": {\"type\": \"string\", \"description\": \"Temperature unit\"},\n                    \"location\": {\"type\": \"string\", \"description\": \"City name\"},\n                },\n                \"required\": [\"location\", \"unit\"],\n            },\n        }\n    ]\n\n    messages = renderer.create_conversation_prefix_with_tools(tools, \"\")\n    tool_declare_content = ensure_text(messages[0][\"content\"])\n\n    # Parse the JSON to check key ordering\n    tools_parsed = json.loads(tool_declare_content)\n\n    # Check top-level keys (should be alphabetically sorted)\n    top_level_keys = list(tools_parsed[0].keys())\n    sorted_top_level = sorted(top_level_keys)\n    assert top_level_keys == sorted_top_level, (\n        f\"Top-level keys not sorted: {top_level_keys} != {sorted_top_level}\"\n    )\n\n    # Check function object keys\n    function_keys = list(tools_parsed[0][\"function\"].keys())\n    sorted_function_keys = sorted(function_keys)\n    assert function_keys == sorted_function_keys, (\n        f\"Function keys not sorted: {function_keys} != {sorted_function_keys}\"\n    )\n\n    # Check nested parameters keys\n    params_keys = list(tools_parsed[0][\"function\"][\"parameters\"].keys())\n    sorted_params_keys = sorted(params_keys)\n    assert params_keys == sorted_params_keys, (\n        f\"Parameters keys not sorted: {params_keys} != {sorted_params_keys}\"\n    )\n\n\ndef test_tool_declaration_matches_hf_tokens():\n    \"\"\"Test that tool declaration produces identical tokens to HuggingFace.\"\"\"\n    # Define tools in ToolSpec format (what tinker-cookbook accepts)\n    tools_toolspec: list[ToolSpec] = [\n        {\n            \"name\": \"get_weather\",\n            \"description\": \"Get current weather\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\"type\": \"string\", \"description\": \"City and state\"},\n                    \"unit\": {\n                        \"type\": \"string\",\n                        \"enum\": [\"C\", \"F\"],\n                        \"description\": \"Temperature unit\",\n                    },\n                },\n                \"required\": [\"location\", \"unit\"],\n            },\n        }\n    ]\n\n    # Convert to OpenAI format for HF (tinker-cookbook does this wrapping internally)\n    tools_openai = [{\"type\": \"function\", \"function\": tool} for tool in tools_toolspec]\n\n    messages: list[Message] = [{\"role\": \"user\", \"content\": \"What's the weather in SF?\"}]\n\n    # Tinker-cookbook approach\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n    convo = renderer.create_conversation_prefix_with_tools(tools_toolspec, \"\") + messages\n    cookbook_tokens = renderer.build_generation_prompt(convo).to_ints()\n\n    # HuggingFace approach (pass OpenAI format to match tinker-cookbook's output)\n    hf_tokenizer = 
AutoTokenizer.from_pretrained(\n        \"moonshotai/Kimi-K2-Thinking\", trust_remote_code=True\n    )\n    hf_tokens = extract_token_ids(\n        hf_tokenizer.apply_chat_template(\n            messages, tools=tools_openai, tokenize=True, add_generation_prompt=True\n        )\n    )\n\n    # Compare tokens\n    cookbook_str = tokenizer.decode(cookbook_tokens)\n    hf_str = hf_tokenizer.decode(hf_tokens)\n\n    assert cookbook_tokens == hf_tokens, (\n        f\"Token mismatch!\\n\"\n        f\"Cookbook tokens: {len(cookbook_tokens)}\\n\"\n        f\"HF tokens: {len(hf_tokens)}\\n\"\n        f\"\\nCookbook string:\\n{cookbook_str[:500]}\\n\"\n        f\"\\nHF string:\\n{hf_str[:500]}\\n\"\n        f\"\\nFirst difference at token {_find_first_diff_index(cookbook_tokens, hf_tokens)}\"\n    )\n\n\ndef test_tool_declaration_string_matches_hf():\n    \"\"\"Test that tool declaration string matches HuggingFace exactly.\"\"\"\n    # ToolSpec format for tinker-cookbook\n    tools_toolspec: list[ToolSpec] = [\n        {\n            \"name\": \"test\",\n            \"description\": \"Test tool\",\n            \"parameters\": {\"type\": \"object\", \"properties\": {}},\n        }\n    ]\n    # OpenAI format for HF\n    tools_openai = [{\"type\": \"function\", \"function\": tool} for tool in tools_toolspec]\n    messages_list: list[Message] = [{\"role\": \"user\", \"content\": \"Test\"}]\n\n    # Tinker-cookbook\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n    convo = renderer.create_conversation_prefix_with_tools(tools_toolspec, \"\") + messages_list\n    cookbook_prompt = renderer.build_generation_prompt(convo)\n    cookbook_str = tokenizer.decode(cookbook_prompt.to_ints())\n\n    # HuggingFace (pass OpenAI format)\n    hf_tokenizer = AutoTokenizer.from_pretrained(\n        \"moonshotai/Kimi-K2-Thinking\", trust_remote_code=True\n    )\n    hf_str = hf_tokenizer.apply_chat_template(\n        messages_list, tools=tools_openai, tokenize=False, add_generation_prompt=True\n    )\n\n    assert cookbook_str == hf_str, (\n        f\"String mismatch!\\n\"\n        f\"\\n=== COOKBOOK ===\\n{cookbook_str[:800]}\\n\"\n        f\"\\n=== HF ===\\n{hf_str[:800]}\\n\"\n    )\n\n\ndef test_empty_tools_list():\n    \"\"\"Test that empty tools list doesn't cause issues.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    messages = renderer.create_conversation_prefix_with_tools([], \"\")\n\n    # Should have exactly one system message\n    assert len(messages) == 1\n    assert messages[0][\"role\"] == \"system\"\n\n\ndef test_custom_system_prompt_with_tools():\n    \"\"\"Test that custom system prompt is preserved when using tools.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    tools: list[ToolSpec] = [\n        {\"name\": \"test\", \"description\": \"Test\", \"parameters\": {\"type\": \"object\", \"properties\": {}}}\n    ]\n    custom_prompt = \"You are a helpful assistant specialized in weather.\"\n\n    messages = renderer.create_conversation_prefix_with_tools(tools, custom_prompt)\n\n    # Should have tool_declare first, then system with custom prompt\n    assert len(messages) == 2\n    assert messages[0][\"role\"] == \"tool_declare\"\n    assert messages[1][\"role\"] == \"system\"\n    assert messages[1][\"content\"] == custom_prompt\n\n\ndef _find_first_diff_index(list1, list2):\n 
   \"\"\"Helper to find first index where two lists differ.\"\"\"\n    for i, (a, b) in enumerate(zip(list1, list2)):\n        if a != b:\n            return i\n    return min(len(list1), len(list2))\n"
  },
  {
    "path": "tinker_cookbook/renderers/llama3.py",
    "content": "\"\"\"Renderer for Llama 3 chat format.\"\"\"\n\nimport tinker\n\nfrom tinker_cookbook.renderers.base import (\n    Message,\n    RenderContext,\n    RenderedMessage,\n    Renderer,\n    ensure_text,\n    parse_response_for_stop_token,\n)\n\n\nclass Llama3Renderer(Renderer):\n    \"\"\"Renderer for Llama 3 Instruct models.\n\n    Format::\n\n        <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n        You are a helpful AI assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n        What can you help me with?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n    Note: We intentionally differ from HF's stock Llama template:\n\n    - HF prepends \"Cutting Knowledge Date...\" to system messages; we don't\n      (add manually if needed)\n\n    Tool calling is NOT supported for Llama 3. The Llama 3 tool calling format\n    uses bare JSON without delimiters, making it impossible to reliably distinguish\n    tool calls from regular JSON content in model responses. Use a different model\n    or develop your own renderer if you need tool calling.\n    \"\"\"\n\n    @property\n    def has_extension_property(self) -> bool:\n        \"\"\"Llama3 satisfies the extension property - no content is stripped from history.\"\"\"\n        return True\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        role = message[\"role\"]\n        header_str = f\"<|start_header_id|>{role}<|end_header_id|>\\n\\n\"\n        output_str = ensure_text(message[\"content\"]) + \"<|eot_id|>\"\n\n        header = tinker.types.EncodedTextChunk(\n            tokens=self.tokenizer.encode(header_str, add_special_tokens=False)\n        )\n        output: list[tinker.ModelInputChunk] = [\n            tinker.types.EncodedTextChunk(\n                tokens=self.tokenizer.encode(output_str, add_special_tokens=False)\n            )\n        ]\n        return RenderedMessage(header=header, output=output)\n\n    @property\n    def _bos_tokens(self) -> list[int]:\n        return self.tokenizer.encode(\"<|begin_of_text|>\", add_special_tokens=False)\n\n    @property\n    def _end_message_token(self) -> int:\n        (token,) = self.tokenizer.encode(\"<|eot_id|>\", add_special_tokens=False)\n        return token\n\n    def get_stop_sequences(self) -> list[int]:\n        return [self._end_message_token]\n\n    def parse_response(self, response: list[int]) -> tuple[Message, bool]:\n        return parse_response_for_stop_token(response, self.tokenizer, self._end_message_token)\n"
  },
  {
    "path": "tinker_cookbook/renderers/nemotron3.py",
    "content": "\"\"\"\nNemotron-3 family renderer.\n\nNemotron-3 models (NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 and\nNVIDIA-Nemotron-3-Super-120B-A12B-BF16) use a chat format similar to Qwen3.5\n(im_start/im_end tokens, thinking blocks, XML-style tool calls) but differ in\nthe following ways:\n\n1. Tool declarations: Nemotron-3 uses structured XML inside <tools>...</tools>\n   (Qwen3.5 uses JSON per line).\n\n2. System message ordering: system_prompt comes BEFORE tools text (Qwen3.5\n   puts tools first).\n\n3. Empty think block scope: Nemotron-3's HF template prepends <think></think>\n   to ALL assistant messages that lack thinking, including historical ones\n   (Qwen3.5 only does this for messages after the last user query).\n\n4. Think block separator: one newline between </think> and text content\n   (Qwen3.5 uses two newlines).\n\n5. Disable-thinking generation suffix: <think></think> with no trailing\n   newlines (Qwen3.5 uses <think>\\\\n\\\\n</think>\\\\n\\\\n).\n\n6. Empty system message injection: Nemotron-3's HF template always outputs\n   a system message block even when none is provided (it always sets\n   system_message = \"\" which is \"defined\" in Jinja2). Our renderer\n   prepends an empty system message in build_generation_prompt and\n   build_supervised_example to match this behavior.\n\n\"\"\"\n\nimport dataclasses\nimport json\nfrom collections.abc import Mapping\n\nfrom tinker_cookbook.renderers.base import (\n    ImagePart,\n    Message,\n    RenderContext,\n    RenderedMessage,\n    Role,\n    TextPart,\n    ToolSpec,\n)\nfrom tinker_cookbook.renderers.qwen3_5 import Qwen3_5Renderer\n\n\ndef _render_extra_keys(obj: Mapping[str, object], handled_keys: set[str]) -> list[str]:\n    \"\"\"Render extra dict keys as XML, mirroring the HF template's render_extra_keys macro.\n\n    Dicts and lists are JSON-encoded; scalars are string-coerced.\n    \"\"\"\n    lines: list[str] = []\n    for key, value in obj.items():\n        if key in handled_keys:\n            continue\n        if isinstance(value, (dict, list)):\n            lines.append(f\"<{key}>{json.dumps(value)}</{key}>\")\n        else:\n            lines.append(f\"<{key}>{value!s}</{key}>\")\n    return lines\n\n\ndef _format_nemotron3_tool_declaration(tool: ToolSpec) -> str:\n    \"\"\"Format a single tool declaration in Nemotron-3's XML format.\n\n    Mirrors the Jinja template logic from chat_template.jinja, including the\n    render_extra_keys macro that outputs additional parameter fields beyond\n    the core name/type/description/enum set (e.g. 
default, minimum, items).\n    \"\"\"\n    lines = [\n        \"<function>\",\n        f\"<name>{tool['name']}</name>\",\n    ]\n    if tool.get(\"description\", \"\").strip():\n        lines.append(f\"<description>{tool['description'].strip()}</description>\")\n    lines.append(\"<parameters>\")\n    params = tool.get(\"parameters\") or {}\n    if isinstance(params, dict) and \"properties\" in params:\n        for param_name, param_fields in params[\"properties\"].items():\n            lines.append(\"<parameter>\")\n            lines.append(f\"<name>{param_name}</name>\")\n            if \"type\" in param_fields:\n                lines.append(f\"<type>{param_fields['type']!s}</type>\")\n            if \"description\" in param_fields:\n                lines.append(f\"<description>{param_fields['description'].strip()}</description>\")\n            if \"enum\" in param_fields:\n                lines.append(f\"<enum>{json.dumps(param_fields['enum'])}</enum>\")\n            lines.extend(_render_extra_keys(param_fields, {\"name\", \"type\", \"description\", \"enum\"}))\n            lines.append(\"</parameter>\")\n    if isinstance(params, dict):\n        lines.extend(_render_extra_keys(params, {\"type\", \"properties\", \"required\"}))\n    if isinstance(params, dict) and \"required\" in params:\n        lines.append(f\"<required>{json.dumps(params['required'])}</required>\")\n    lines.append(\"</parameters>\")\n    lines.extend(_render_extra_keys(tool, {\"type\", \"name\", \"description\", \"parameters\"}))\n    lines.append(\"</function>\")\n    return \"\\n\".join(lines)\n\n\nclass Nemotron3Renderer(Qwen3_5Renderer):\n    \"\"\"Renderer for Nemotron-3 models.\n\n    Subclasses Qwen3_5Renderer for the shared im_start/im_end/thinking/tool-call\n    infrastructure, overriding the parts that differ from Qwen3.5:\n\n    - _assistant_header_suffix: adds <think></think> to ALL assistant messages\n      whose thinking will NOT appear in the output.\n    - render_message: strips thinking only for messages before last_user_index\n      (matching HF template's truncate_history_thinking condition).\n    - _format_thinking_text: one separator newline after </think> (not two).\n    - _format_tool_calls_chunks: single newline prefix + trailing newline after\n      each </tool_call> (matching HF template format).\n    - parse_response: strips one newline separator after </think> (not two).\n    - create_conversation_prefix_with_tools: XML tool declarations, system\n      prompt before tools.\n    - build_generation_prompt / build_supervised_example: inject empty system\n      message when none is present, matching HF template behavior.\n    \"\"\"\n\n    def _normalize_messages(self, messages: list[Message]) -> list[Message]:\n        \"\"\"Prepend empty system message if not present.\n\n        Nemotron-3's HF template always outputs a system message block even\n        when none is provided (because it always sets system_message = \"\" which\n        is \"defined\" in Jinja2). 
This ensures our token sequence matches.\n        \"\"\"\n        if not messages or messages[0][\"role\"] != \"system\":\n            return [Message(role=\"system\", content=\"\")] + list(messages)\n        return messages\n\n    def build_generation_prompt(self, messages: list[Message], *args: object, **kwargs: object):  # type: ignore[override]\n        return super().build_generation_prompt(self._normalize_messages(messages), *args, **kwargs)  # type: ignore[arg-type]\n\n    def build_supervised_example(self, messages: list[Message], *args: object, **kwargs: object):  # type: ignore[override]\n        return super().build_supervised_example(self._normalize_messages(messages), *args, **kwargs)  # type: ignore[arg-type]\n\n    def _assistant_header_suffix(self, message: Message, ctx: RenderContext) -> str:\n        \"\"\"Prepend <think></think> when thinking will not appear in the output.\n\n        Nemotron-3's HF template prepends <think></think> to assistant message\n        content when there are no <think> tags in the output:\n        - Historical messages (idx < last_user_index): thinking is stripped,\n          so <think></think> is always prepended regardless of original content.\n        - Non-historical messages: prepend only if the message has no thinking.\n\n        When a historical message has non-empty text content, the HF template\n        produces \"<think></think>\\\\ntext\" (with a newline separator). This comes\n        from c.split('</think>')[-1] preserving the \\\\n in _format_thinking_text's\n        output. We add the \\\\n to the header suffix in that case.\n        \"\"\"\n        is_historical = ctx.idx < ctx.last_user_index\n        content = message.get(\"content\", \"\")\n        has_think = False\n        if isinstance(content, list):\n            has_think = any(p[\"type\"] == \"thinking\" for p in content)\n        elif isinstance(content, str):\n            has_think = \"<think>\" in content\n        # Non-historical with thinking: thinking will be in output, no prefix needed.\n        if has_think and not is_historical:\n            return \"\"\n        # For historical messages with stripped thinking and non-empty text:\n        # add \\n separator to match HF template's c.split('</think>')[-1] behavior.\n        # Exception: when the message has tool_calls, the HF template's tool_calls\n        # branch applies ``| trim`` (which binds tighter than ``~``) to the content\n        # before concatenation, stripping the leading \\n. 
So no separator in that case.\n        if is_historical and has_think:\n            has_nonempty_text = isinstance(content, list) and any(\n                p[\"type\"] == \"text\" and p.get(\"text\", \"\") for p in content\n            )\n            has_tool_calls = bool(message.get(\"tool_calls\"))\n            if has_nonempty_text and not has_tool_calls:\n                return \"<think></think>\\n\"\n        return \"<think></think>\"\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        \"\"\"Render a message, using idx < last_user_index for thinking stripping.\n\n        Nemotron-3's HF template strips thinking only for messages BEFORE the\n        last user message (truncate_history_thinking and loop.index0 < last_user_idx).\n        The base Qwen3VLRenderer uses `not ctx.is_last`, which incorrectly strips\n        thinking from assistant messages that follow tool responses.\n        \"\"\"\n        if message[\"role\"] == \"assistant\" and ctx.idx >= ctx.last_user_index and not ctx.is_last:\n            # Prevent thinking from being stripped: patch is_last=True so that\n            # the base class's `not ctx.is_last` condition evaluates to False.\n            ctx = dataclasses.replace(ctx, is_last=True)\n        return super().render_message(message, ctx)\n\n    def _format_thinking_text(self, thinking: str) -> str:\n        \"\"\"Nemotron-3 uses a single separator newline after </think>.\"\"\"\n        return f\"<think>\\n{thinking}\\n</think>\\n\"\n\n    def _wrap_qwen_tool_response_chunks(\n        self, chunks: list[ImagePart | TextPart]\n    ) -> list[ImagePart | TextPart]:\n        \"\"\"Wrap tool response with Nemotron-3's format.\n\n        Nemotron-3 HF template outputs '\\\\n</tool_response>\\\\n' (with trailing \\\\n),\n        while the base class uses '\\\\n</tool_response>' (no trailing \\\\n).\n        \"\"\"\n        return (\n            [TextPart(type=\"text\", text=\"<tool_response>\\n\")]\n            + chunks\n            + [TextPart(type=\"text\", text=\"\\n</tool_response>\\n\")]\n        )\n\n    def _format_tool_calls_chunks(self, message: Message) -> list[ImagePart | TextPart]:\n        \"\"\"Format tool_calls for Nemotron-3.\n\n        Differences from Qwen3.5:\n        - Single newline prefix (not two) before the first <tool_call>, unless the\n          preceding content (thinking) already ends with \\\\n.\n        - Trailing \\\\n after each </tool_call> (matching HF template line 156:\n          '</function>\\\\n</tool_call>\\\\n').\n\n        The prefix is omitted when the message has thinking-only content (no\n        non-empty text parts), because _format_thinking_text already ends with \\\\n.\n        \"\"\"\n        assert \"tool_calls\" in message\n        content = message.get(\"content\", \"\")\n        has_thinking = isinstance(content, list) and any(p[\"type\"] == \"thinking\" for p in content)\n        has_nonempty_text = isinstance(content, list) and any(\n            p[\"type\"] == \"text\" and p.get(\"text\", \"\") for p in content\n        )\n        # Thinking ends with \\n; only add \\n prefix if there's text after thinking\n        # (which won't end with \\n) or no thinking at all.\n        prefix = \"\" if (has_thinking and not has_nonempty_text) else \"\\n\"\n        calls = \"\".join(self._format_tool_call_xml(tc) + \"\\n\" for tc in message[\"tool_calls\"])\n        return [TextPart(type=\"text\", text=prefix + calls)]\n\n    def _postprocess_parsed_message(self, message: Message) -> 
None:\n        \"\"\"Strip one separator newline after </think> (not two like Qwen3.5).\n\n        Nemotron-3 uses a single ``\\\\n`` between ``</think>`` and text content,\n        while Qwen3.5 uses ``\\\\n\\\\n``. We strip the single newline here, then\n        delegate to the parent for thinking whitespace trimming and XML tool\n        call conversion. This ensures both ``parse_response`` and\n        ``_parse_response_for_streaming`` get the correct behavior.\n        \"\"\"\n        content = message.get(\"content\")\n        if isinstance(content, list):\n            seen_thinking = False\n            for p in content:\n                if p[\"type\"] == \"thinking\":\n                    seen_thinking = True\n                elif seen_thinking and p[\"type\"] == \"text\":\n                    # Strip exactly one separator newline (Nemotron-3's format).\n                    # Do this before super() so its \\n\\n check becomes a no-op.\n                    if p[\"text\"].startswith(\"\\n\"):\n                        p[\"text\"] = p[\"text\"][1:]\n                    break\n        super()._postprocess_parsed_message(message)\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> list[Message]:\n        \"\"\"Create system message with Nemotron-3 XML tool specifications.\n\n        Nemotron-3 uses structured XML for tool declarations (unlike Qwen3.5's\n        JSON-per-line format). The system prompt also comes *before* the tools\n        block, matching the HF template:\n\n            <|im_start|>system\n            {system_prompt}\n\n            # Tools\n            ...\n        \"\"\"\n        tools_text = \"\"\n        if tools:\n            tool_declarations = \"\\n\".join(\n                _format_nemotron3_tool_declaration(tool) for tool in tools\n            )\n            tools_text = (\n                \"# Tools\\n\\n\"\n                \"You have access to the following functions:\\n\\n\"\n                \"<tools>\\n\"\n                f\"{tool_declarations}\\n\"\n                \"</tools>\\n\\n\"\n                \"If you choose to call a function ONLY reply in the following format with NO suffix:\\n\\n\"\n                \"<tool_call>\\n\"\n                \"<function=example_function_name>\\n\"\n                \"<parameter=example_parameter_1>\\n\"\n                \"value_1\\n\"\n                \"</parameter>\\n\"\n                \"<parameter=example_parameter_2>\\n\"\n                \"This is the value for the second parameter\\n\"\n                \"that can span\\n\"\n                \"multiple lines\\n\"\n                \"</parameter>\\n\"\n                \"</function>\\n\"\n                \"</tool_call>\\n\\n\"\n                \"<IMPORTANT>\\n\"\n                \"Reminder:\\n\"\n                \"- Function calls MUST follow the specified format: \"\n                \"an inner <function=...></function> block must be nested within \"\n                \"<tool_call></tool_call> XML tags\\n\"\n                \"- Required parameters MUST be specified\\n\"\n                \"- You may provide optional reasoning for your function call in natural language \"\n                \"BEFORE the function call, but NOT after\\n\"\n                \"- If there is no function call available, answer the question like normal with \"\n                \"your current knowledge and do not tell the user about function calls\\n\"\n                \"</IMPORTANT>\"\n            )\n\n        if tools_text:\n        
    # Nemotron-3 puts system_prompt BEFORE tools (opposite of Qwen3.5)\n            content = system_prompt + \"\\n\\n\" + tools_text if system_prompt else tools_text\n        else:\n            content = system_prompt\n\n        return [Message(role=\"system\", content=content)]\n\n\nclass Nemotron3DisableThinkingRenderer(Nemotron3Renderer):\n    \"\"\"Renderer for Nemotron-3 models with thinking disabled.\n\n    Matches the Nemotron-3 HF template with enable_thinking=False. The only\n    difference from Nemotron3Renderer is the generation suffix:\n    <think></think> (no trailing newlines) instead of <think>\\\\n.\n    \"\"\"\n\n    def _get_generation_suffix(self, role: Role, ctx: RenderContext) -> list[int]:\n        maybe_newline = \"\\n\" if ctx.idx > 0 else \"\"\n        header_str = f\"{maybe_newline}<|im_start|>{role}\\n<think></think>\"\n        return self.tokenizer.encode(header_str, add_special_tokens=False)\n"
  },
  {
    "path": "tinker_cookbook/renderers/nemotron3_test.py",
    "content": "\"\"\"\nTests for Nemotron-3 renderer.\n\nTests verify that the Nemotron3Renderer produces correct output:\n1. Generation prompt ends with <|im_start|>assistant\\n<think>\\n (thinking enabled)\n2. Disable-thinking variant ends with <|im_start|>assistant\\n<think></think>\n3. Tool declarations use Nemotron-3's structured XML format\n4. System prompt comes BEFORE tools in the system message\n5. <think></think> is prepended to ALL assistant messages without thinking (not just last)\n6. HF template compatibility for both build_generation_prompt and build_supervised_example\n\"\"\"\n\nimport json\n\nimport pytest\n\nfrom tinker_cookbook.renderers import Message, ToolCall, ToolSpec, get_renderer\nfrom tinker_cookbook.renderers.nemotron3 import (\n    Nemotron3DisableThinkingRenderer,\n    Nemotron3Renderer,\n    _format_nemotron3_tool_declaration,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\nNEMOTRON3_NANO_MODEL = \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\"\nNEMOTRON3_SUPER_MODEL = \"nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16\"\nNEMOTRON3_TOKENIZER_PATH = \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\"\n\n\n# =============================================================================\n# Test Fixtures\n# =============================================================================\n\n\n@pytest.fixture(scope=\"module\")\ndef nemotron_tokenizer():\n    return get_tokenizer(NEMOTRON3_TOKENIZER_PATH)\n\n\n@pytest.fixture(scope=\"module\")\ndef nemotron_renderer(nemotron_tokenizer):\n    return get_renderer(\"nemotron3\", nemotron_tokenizer)\n\n\n@pytest.fixture(scope=\"module\")\ndef nemotron_renderer_disable_thinking(nemotron_tokenizer):\n    return get_renderer(\"nemotron3_disable_thinking\", nemotron_tokenizer)\n\n\ndef _hf_generation_tokens(tokenizer, hf_messages, tools=None, enable_thinking: bool = True):\n    \"\"\"Run HF apply_chat_template with generation prompt and return token list.\"\"\"\n    kwargs = {\"add_generation_prompt\": True, \"tokenize\": True, \"enable_thinking\": enable_thinking}\n    if tools is not None:\n        kwargs[\"tools\"] = tools\n    result = tokenizer.apply_chat_template(hf_messages, **kwargs)\n    # apply_chat_template may return BatchEncoding (dict-like) when tools are provided.\n    if hasattr(result, \"input_ids\"):\n        return list(result.input_ids)\n    return list(result)\n\n\ndef _hf_supervised_tokens(tokenizer, hf_messages, tools=None, enable_thinking: bool = True):\n    \"\"\"Run HF apply_chat_template without generation prompt, strip trailing newline, re-encode.\"\"\"\n    kwargs = {\"add_generation_prompt\": False, \"tokenize\": False, \"enable_thinking\": enable_thinking}\n    if tools is not None:\n        kwargs[\"tools\"] = tools\n    result = tokenizer.apply_chat_template(hf_messages, **kwargs)\n    # apply_chat_template with tokenize=False may return BatchEncoding when tools are provided.\n    hf_str = result.input_ids if hasattr(result, \"input_ids\") else result\n    assert isinstance(hf_str, str)\n    return tokenizer.encode(hf_str.rstrip(\"\\n\"), add_special_tokens=False)\n\n\n# =============================================================================\n# Test Conversations\n# =============================================================================\n\n\ndef get_basic_conversation_for_generation() -> list[Message]:\n    \"\"\"3-turn conversation ending with user message (for generation).\"\"\"\n    return [\n        Message(role=\"system\", content=\"You are a helpful assistant.\"),\n      
  Message(role=\"user\", content=\"Hello, how are you?\"),\n        Message(role=\"assistant\", content=\"I'm fine, thank you!\"),\n        Message(role=\"user\", content=\"What is the capital of France?\"),\n    ]\n\n\ndef get_basic_conversation_for_supervised() -> list[Message]:\n    \"\"\"2-turn conversation ending with assistant (for supervised).\"\"\"\n    return [\n        Message(role=\"system\", content=\"You are a helpful assistant.\"),\n        Message(role=\"user\", content=\"Hello, how are you?\"),\n        Message(role=\"assistant\", content=\"I'm fine, thank you!\"),\n    ]\n\n\ndef get_thinking_conversation_for_supervised() -> list[Message]:\n    \"\"\"Conversation with thinking content, ending with assistant.\"\"\"\n    return [\n        Message(role=\"system\", content=\"You are a helpful assistant.\"),\n        Message(role=\"user\", content=\"Solve 2+2.\"),\n        Message(\n            role=\"assistant\",\n            content=[\n                {\"type\": \"thinking\", \"thinking\": \"2 plus 2 equals 4.\"},\n                {\"type\": \"text\", \"text\": \"The answer is 4.\"},\n            ],\n        ),\n    ]\n\n\ndef get_multiturn_thinking_conversation() -> list[Message]:\n    \"\"\"Multi-turn with thinking in both assistant messages.\"\"\"\n    return [\n        Message(role=\"system\", content=\"You are a helpful assistant.\"),\n        Message(role=\"user\", content=\"First question.\"),\n        Message(\n            role=\"assistant\",\n            content=[\n                {\"type\": \"thinking\", \"thinking\": \"First turn reasoning.\"},\n                {\"type\": \"text\", \"text\": \"First answer.\"},\n            ],\n        ),\n        Message(role=\"user\", content=\"Second question.\"),\n        Message(\n            role=\"assistant\",\n            content=[\n                {\"type\": \"thinking\", \"thinking\": \"Second turn reasoning.\"},\n                {\"type\": \"text\", \"text\": \"Second answer.\"},\n            ],\n        ),\n    ]\n\n\ndef get_tool_spec() -> ToolSpec:\n    return ToolSpec(\n        name=\"get_weather\",\n        description=\"Get the current weather for a location\",\n        parameters={\n            \"type\": \"object\",\n            \"properties\": {\n                \"location\": {\n                    \"type\": \"string\",\n                    \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n                },\n                \"unit\": {\n                    \"type\": \"string\",\n                    \"enum\": [\"celsius\", \"fahrenheit\"],\n                    \"description\": \"Temperature unit\",\n                },\n            },\n            \"required\": [\"location\"],\n        },\n    )\n\n\ndef get_rich_tool_spec() -> ToolSpec:\n    \"\"\"Tool spec with extra JSON Schema fields beyond name/type/description/enum.\"\"\"\n    return ToolSpec(\n        name=\"search\",\n        description=\"Search for items\",\n        parameters={\n            \"type\": \"object\",\n            \"properties\": {\n                \"query\": {\n                    \"type\": \"string\",\n                    \"description\": \"Search query\",\n                    \"default\": \"*\",\n                },\n                \"max_results\": {\n                    \"type\": \"integer\",\n                    \"description\": \"Maximum number of results\",\n                    \"minimum\": 1,\n                    \"maximum\": 100,\n                },\n                \"tags\": {\n                    \"type\": \"array\",\n                    \"description\": \"Filter tags\",\n                    \"items\": {\"type\": \"string\"},\n                },\n            },\n            \"required\": [\"query\"],\n            \"additionalProperties\": False,\n        },\n    )\n\n\ndef get_tool_call_conversation_for_generation() -> tuple[list[Message], list[ToolSpec]]:\n    tools = [get_tool_spec()]\n    tool_call = ToolCall(\n        id=\"call_abc123\",\n        function=ToolCall.FunctionBody(\n            name=\"get_weather\",\n            arguments='{\"location\": \"New York, NY\"}',\n        ),\n    )\n    messages: list[Message] = [\n        Message(role=\"user\", content=\"What's the weather in NYC?\"),\n        Message(\n            role=\"assistant\",\n            content=[\n                {\"type\": \"thinking\", \"thinking\": \"I need to check the weather in NYC.\"},\n                {\"type\": \"text\", \"text\": \"\"},\n            ],\n            tool_calls=[tool_call],\n        ),\n        Message(\n            role=\"tool\",\n            name=\"get_weather\",\n            tool_call_id=\"call_abc123\",\n            content='{\"temperature\": 72, \"condition\": \"sunny\"}',\n        ),\n    ]\n    return messages, tools\n\n\ndef get_historical_tool_call_with_nonempty_text_conversation() -> tuple[\n    list[Message], list[ToolSpec]\n]:\n    \"\"\"Conversation where a historical assistant message has thinking + non-empty text + tool_calls.\n\n    This is an edge case where the HF Jinja template's tool_calls branch applies\n    ``| trim`` to the content *before* concatenation with ``<think></think>``,\n    stripping the leading ``\\\\n`` that would otherwise be preserved in the\n    non-tool_calls branch. 
The result is ``<think></think>text`` (no newline)\n    for the historical message.\n\n    The first assistant message becomes historical because a later user message\n    follows the tool response + second assistant exchange.\n    \"\"\"\n    tools = [get_tool_spec()]\n    tool_call = ToolCall(\n        id=\"call_abc123\",\n        function=ToolCall.FunctionBody(\n            name=\"get_weather\",\n            arguments='{\"location\": \"New York, NY\"}',\n        ),\n    )\n    messages: list[Message] = [\n        Message(role=\"user\", content=\"What's the weather in NYC?\"),\n        # This assistant message has thinking + non-empty text + tool_calls\n        # and will be historical (before the last user message).\n        Message(\n            role=\"assistant\",\n            content=[\n                {\"type\": \"thinking\", \"thinking\": \"I need to check the weather.\"},\n                {\"type\": \"text\", \"text\": \"Let me check that for you.\"},\n            ],\n            tool_calls=[tool_call],\n        ),\n        Message(\n            role=\"tool\",\n            name=\"get_weather\",\n            tool_call_id=\"call_abc123\",\n            content='{\"temperature\": 72, \"condition\": \"sunny\"}',\n        ),\n        Message(\n            role=\"assistant\",\n            content=[\n                {\"type\": \"thinking\", \"thinking\": \"The weather is 72F and sunny.\"},\n                {\"type\": \"text\", \"text\": \"It's 72°F and sunny in NYC.\"},\n            ],\n        ),\n        Message(role=\"user\", content=\"Thanks!\"),\n    ]\n    return messages, tools\n\n\ndef get_tool_call_conversation_for_supervised() -> tuple[list[Message], list[ToolSpec]]:\n    tools = [get_tool_spec()]\n    tool_call = ToolCall(\n        id=\"call_abc123\",\n        function=ToolCall.FunctionBody(\n            name=\"get_weather\",\n            arguments='{\"location\": \"New York, NY\"}',\n        ),\n    )\n    messages: list[Message] = [\n        Message(role=\"user\", content=\"What's the weather in NYC?\"),\n        Message(\n            role=\"assistant\",\n            content=[\n                {\"type\": \"thinking\", \"thinking\": \"I need to check the weather in NYC.\"},\n                {\"type\": \"text\", \"text\": \"\"},\n            ],\n            tool_calls=[tool_call],\n        ),\n        Message(\n            role=\"tool\",\n            name=\"get_weather\",\n            tool_call_id=\"call_abc123\",\n            content='{\"temperature\": 72, \"condition\": \"sunny\"}',\n        ),\n        Message(\n            role=\"assistant\",\n            content=[\n                {\"type\": \"thinking\", \"thinking\": \"The weather is 72F and sunny.\"},\n                {\"type\": \"text\", \"text\": \"The weather in NYC is 72°F and sunny.\"},\n            ],\n        ),\n    ]\n    return messages, tools\n\n\n# =============================================================================\n# Tool Declaration Format Tests (no tokenizer required)\n# =============================================================================\n\n\ndef test_tool_declaration_xml_format():\n    \"\"\"Tool declarations use Nemotron-3's structured XML format.\"\"\"\n    tool = get_tool_spec()\n    declaration = _format_nemotron3_tool_declaration(tool)\n\n    assert \"<function>\" in declaration\n    assert \"<name>get_weather</name>\" in declaration\n    assert \"<description>Get the current weather for a location</description>\" in declaration\n    assert \"<parameters>\" in declaration\n    
assert \"<parameter>\" in declaration\n    assert \"<name>location</name>\" in declaration\n    assert \"<type>string</type>\" in declaration\n    assert \"<name>unit</name>\" in declaration\n    assert \"<enum>\" in declaration\n    assert '\"celsius\"' in declaration\n    assert \"<required>\" in declaration\n    assert '\"location\"' in declaration\n    assert \"</function>\" in declaration\n\n\ndef test_tool_declaration_not_json_per_line():\n    \"\"\"Tool declarations should NOT use Qwen3.5's JSON-per-line format.\"\"\"\n    tool = get_tool_spec()\n    declaration = _format_nemotron3_tool_declaration(tool)\n    assert not declaration.strip().startswith(\"{\")\n    assert '\"name\": \"get_weather\"' not in declaration\n\n\ndef test_tool_declaration_minimal_tool():\n    \"\"\"Tool with no description and no parameters.\"\"\"\n    tool = ToolSpec(name=\"ping\", description=\"\", parameters={})\n    declaration = _format_nemotron3_tool_declaration(tool)\n    assert \"<name>ping</name>\" in declaration\n    assert \"<description>\" not in declaration\n    assert \"<parameter>\" not in declaration\n    assert \"<required>\" not in declaration\n\n\ndef test_tool_declaration_extra_schema_keys_match_hf(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Tool with extra JSON Schema fields (default, minimum, items, etc.) matches HF.\n\n    The HF Jinja template has a render_extra_keys macro that outputs additional\n    parameter fields beyond name/type/description/enum. This test verifies our\n    renderer handles those extra keys the same way.\n    \"\"\"\n    tools = [get_rich_tool_spec()]\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n    system_prompt = \"You are a helpful assistant.\"\n\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=system_prompt\n    )\n    messages = prefix + [Message(role=\"user\", content=\"Search for cats\")]\n    cookbook = nemotron_renderer.build_generation_prompt(messages).to_ints()\n\n    hf_messages = [\n        {\"role\": \"system\", \"content\": system_prompt},\n        {\"role\": \"user\", \"content\": \"Search for cats\"},\n    ]\n    hf = _hf_generation_tokens(nemotron_tokenizer, hf_messages, tools=openai_tools)\n\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_create_conversation_prefix_system_before_tools(nemotron_renderer):\n    \"\"\"System prompt should appear BEFORE tools block.\"\"\"\n    tools = [get_tool_spec()]\n    system_prompt = \"You are a helpful assistant.\"\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(tools, system_prompt)\n\n    assert len(prefix) == 1\n    assert prefix[0][\"role\"] == \"system\"\n    content = prefix[0][\"content\"]\n    assert isinstance(content, str)\n\n    sysprompt_idx = content.index(system_prompt)\n    tools_idx = content.index(\"# Tools\")\n    assert sysprompt_idx < tools_idx\n\n\ndef test_create_conversation_prefix_without_system_prompt(nemotron_renderer):\n    \"\"\"Without system_prompt, content starts directly with # Tools.\"\"\"\n    tools = [get_tool_spec()]\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(tools)\n    content = prefix[0][\"content\"]\n    assert isinstance(content, str)\n    assert content.startswith(\"# Tools\")\n\n\ndef test_create_conversation_prefix_xml_tool_format(nemotron_renderer):\n    \"\"\"Tool declarations in prefix use XML format, not 
JSON.\"\"\"\n    tools = [get_tool_spec()]\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(tools)\n    content = prefix[0][\"content\"]\n    assert \"<tools>\" in content\n    assert \"<function>\" in content\n    assert \"<name>get_weather</name>\" in content\n    assert \"<parameter>\" in content\n    assert \"</tools>\" in content\n    assert '{\"name\": \"get_weather\"' not in content\n\n\ndef test_create_conversation_prefix_no_tools(nemotron_renderer):\n    \"\"\"No tools: returns just the system_prompt.\"\"\"\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(\n        [], system_prompt=\"You are helpful.\"\n    )\n    assert prefix[0][\"content\"] == \"You are helpful.\"\n\n\n# =============================================================================\n# Generation Prompt Tests\n# =============================================================================\n\n\ndef test_generation_prompt_ends_with_think(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Nemotron3Renderer prefills with <think>\\\\n.\"\"\"\n    messages = get_basic_conversation_for_generation()\n    decoded = nemotron_tokenizer.decode(\n        nemotron_renderer.build_generation_prompt(messages).to_ints()\n    )\n    assert decoded.endswith(\"<|im_start|>assistant\\n<think>\\n\")\n\n\ndef test_disable_thinking_generation_prompt(nemotron_tokenizer, nemotron_renderer_disable_thinking):\n    \"\"\"Nemotron3DisableThinkingRenderer prefills with <think></think>.\"\"\"\n    messages = get_basic_conversation_for_generation()\n    decoded = nemotron_tokenizer.decode(\n        nemotron_renderer_disable_thinking.build_generation_prompt(messages).to_ints()\n    )\n    assert decoded.endswith(\"<|im_start|>assistant\\n<think></think>\")\n\n\ndef test_custom_prefill_overrides_think(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Custom prefill overrides the default <think>\\\\n.\"\"\"\n    messages = get_basic_conversation_for_generation()\n    decoded = nemotron_tokenizer.decode(\n        nemotron_renderer.build_generation_prompt(messages, prefill=\"Sure, \").to_ints()\n    )\n    assert decoded.endswith(\"Sure, \")\n    assert not decoded.endswith(\"<think>\\n\")\n\n\n# =============================================================================\n# HF Template Compatibility Tests — Generation\n# =============================================================================\n\n\ndef test_basic_conversation_generation_matches_hf(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Basic conversation generation matches HF template.\"\"\"\n    messages = get_basic_conversation_for_generation()\n    cookbook = nemotron_renderer.build_generation_prompt(messages).to_ints()\n    hf = _hf_generation_tokens(\n        nemotron_tokenizer,\n        [nemotron_renderer.to_openai_message(m) for m in messages],\n    )\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_disable_thinking_generation_matches_hf(\n    nemotron_tokenizer, nemotron_renderer_disable_thinking\n):\n    \"\"\"Disable-thinking generation matches HF with enable_thinking=False.\"\"\"\n    messages = [\n        Message(role=\"system\", content=\"You are helpful.\"),\n        Message(role=\"user\", content=\"Hello\"),\n        Message(role=\"assistant\", content=\"Hi!\"),\n        Message(role=\"user\", content=\"What is 2+2?\"),\n    ]\n    r = nemotron_renderer_disable_thinking\n    cookbook = 
r.build_generation_prompt(messages).to_ints()\n    hf = _hf_generation_tokens(\n        nemotron_tokenizer,\n        [r.to_openai_message(m) for m in messages],\n        enable_thinking=False,\n    )\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\n# =============================================================================\n# HF Template Compatibility Tests — Supervised\n# =============================================================================\n\n\ndef test_basic_conversation_supervised_matches_hf(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Basic supervised example matches HF template (no gen prompt).\"\"\"\n    messages = get_basic_conversation_for_supervised()\n    cookbook = nemotron_renderer.build_supervised_example(messages)[0].to_ints()\n    hf = _hf_supervised_tokens(\n        nemotron_tokenizer,\n        [nemotron_renderer.to_openai_message(m) for m in messages],\n    )\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_thinking_conversation_supervised_matches_hf(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Supervised example with thinking content matches HF template.\"\"\"\n    messages = get_thinking_conversation_for_supervised()\n    cookbook = nemotron_renderer.build_supervised_example(messages)[0].to_ints()\n    hf = _hf_supervised_tokens(\n        nemotron_tokenizer,\n        [nemotron_renderer.to_openai_message(m) for m in messages],\n    )\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_multiturn_thinking_supervised_matches_hf(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Multi-turn with thinking in both assistant messages matches HF template.\n\n    Nemotron-3's HF template truncates thinking in historical messages to\n    <think></think>. 
This test verifies our renderer does the same.\n    \"\"\"\n    messages = get_multiturn_thinking_conversation()\n    cookbook = nemotron_renderer.build_supervised_example(messages)[0].to_ints()\n    hf = _hf_supervised_tokens(\n        nemotron_tokenizer,\n        [nemotron_renderer.to_openai_message(m) for m in messages],\n    )\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_think_block_added_to_all_assistant_history(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"<think></think> is prepended to historical assistant messages without thinking.\"\"\"\n    messages = get_basic_conversation_for_generation()  # ends with user, has one assistant\n    decoded = nemotron_tokenizer.decode(\n        nemotron_renderer.build_generation_prompt(messages).to_ints()\n    )\n    # The historical assistant message should have <think></think> prepended\n    assert \"<think></think>I'm fine, thank you!\" in decoded\n\n\n# =============================================================================\n# HF Template Compatibility Tests — Tool Declarations\n# =============================================================================\n\n\n@pytest.mark.parametrize(\"build_mode\", [\"generation\", \"supervised\"])\ndef test_tool_declaration_matches_hf(build_mode: str, nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Tool declarations match HF template output.\"\"\"\n    tools = [get_tool_spec()]\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n    system_prompt = \"You are a helpful assistant.\"\n\n    prefix_messages = nemotron_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=system_prompt\n    )\n    user_msg = Message(role=\"user\", content=\"What's the weather in NYC?\")\n\n    hf_messages = [\n        {\"role\": \"system\", \"content\": system_prompt},\n        {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n    ]\n\n    if build_mode == \"generation\":\n        cookbook = nemotron_renderer.build_generation_prompt(prefix_messages + [user_msg]).to_ints()\n        hf = _hf_generation_tokens(nemotron_tokenizer, hf_messages, tools=openai_tools)\n    else:\n        assistant_msg = Message(role=\"assistant\", content=\"Let me check that for you.\")\n        cookbook = nemotron_renderer.build_supervised_example(\n            prefix_messages + [user_msg, assistant_msg]\n        )[0].to_ints()\n        hf_messages.append({\"role\": \"assistant\", \"content\": \"Let me check that for you.\"})\n        hf = _hf_supervised_tokens(nemotron_tokenizer, hf_messages, tools=openai_tools)\n\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_tool_call_conversation_generation_matches_hf(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Tool call + tool response conversation (generation) matches HF template.\"\"\"\n    messages, tools = get_tool_call_conversation_for_generation()\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n    system_prompt = \"You are a helpful assistant.\"\n\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=system_prompt\n    )\n    cookbook = nemotron_renderer.build_generation_prompt(prefix + messages).to_ints()\n\n    hf_messages = [\n        {\"role\": \"system\", \"content\": system_prompt},\n        
*[nemotron_renderer.to_openai_message(m) for m in messages],\n    ]\n    hf = _hf_generation_tokens(nemotron_tokenizer, hf_messages, tools=openai_tools)\n\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_tool_call_conversation_supervised_matches_hf(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"Complete tool call conversation (supervised) matches HF template.\"\"\"\n    messages, tools = get_tool_call_conversation_for_supervised()\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n    system_prompt = \"You are a helpful assistant.\"\n\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=system_prompt\n    )\n    cookbook = nemotron_renderer.build_supervised_example(prefix + messages)[0].to_ints()\n\n    hf_messages = [\n        {\"role\": \"system\", \"content\": system_prompt},\n        *[nemotron_renderer.to_openai_message(m) for m in messages],\n    ]\n    hf = _hf_supervised_tokens(nemotron_tokenizer, hf_messages, tools=openai_tools)\n\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_historical_tool_call_with_nonempty_text_generation_matches_hf(\n    nemotron_tokenizer, nemotron_renderer\n):\n    \"\"\"Historical tool_call message with thinking + non-empty text matches HF.\n\n    In the HF Jinja template's tool_calls branch, ``| trim`` binds tighter than\n    ``~``, so the leading ``\\\\n`` from the content is stripped before concatenation\n    with ``<think></think>``, producing ``<think></think>text`` (no newline).\n    This differs from the non-tool_calls branch which preserves the newline.\n    \"\"\"\n    messages, tools = get_historical_tool_call_with_nonempty_text_conversation()\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n    system_prompt = \"You are a helpful assistant.\"\n\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=system_prompt\n    )\n    cookbook = nemotron_renderer.build_generation_prompt(prefix + messages).to_ints()\n\n    hf_messages = [\n        {\"role\": \"system\", \"content\": system_prompt},\n        *[nemotron_renderer.to_openai_message(m) for m in messages],\n    ]\n    hf = _hf_generation_tokens(nemotron_tokenizer, hf_messages, tools=openai_tools)\n\n    assert cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\ndef test_historical_tool_call_with_nonempty_text_supervised_matches_hf(\n    nemotron_tokenizer, nemotron_renderer\n):\n    \"\"\"Supervised version of the historical tool_call + non-empty text edge case.\"\"\"\n    messages, tools = get_historical_tool_call_with_nonempty_text_conversation()\n    openai_tools = [{\"type\": \"function\", \"function\": tool} for tool in tools]\n    system_prompt = \"You are a helpful assistant.\"\n\n    prefix = nemotron_renderer.create_conversation_prefix_with_tools(\n        tools, system_prompt=system_prompt\n    )\n    cookbook = nemotron_renderer.build_supervised_example(prefix + messages)[0].to_ints()\n\n    hf_messages = [\n        {\"role\": \"system\", \"content\": system_prompt},\n        *[nemotron_renderer.to_openai_message(m) for m in messages],\n    ]\n    hf = _hf_supervised_tokens(nemotron_tokenizer, hf_messages, tools=openai_tools)\n\n    assert 
cookbook == hf, (\n        f\"Cookbook: {nemotron_tokenizer.decode(cookbook)}\\nHF: {nemotron_tokenizer.decode(hf)}\"\n    )\n\n\n# =============================================================================\n# Parse Response Tests\n# =============================================================================\n\n\ndef test_parse_response_plain_text(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"parse_response decodes a plain text response (no thinking).\"\"\"\n    tokens = nemotron_tokenizer.encode(\"The answer is 42.<|im_end|>\", add_special_tokens=False)\n    message, success = nemotron_renderer.parse_response(tokens)\n    assert success\n    from tinker_cookbook.renderers import get_text_content\n\n    assert \"42\" in get_text_content(message)\n\n\ndef test_parse_response_with_thinking(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"parse_response extracts thinking content from the response.\"\"\"\n    # Simulates what the model generates after the <think>\\n prefill\n    response_text = \"I should reason carefully.\\n</think>\\nThe answer is 42.<|im_end|>\"\n    tokens = nemotron_tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = nemotron_renderer.parse_response(tokens)\n\n    assert success\n    content = message.get(\"content\")\n    assert isinstance(content, list)\n    thinking_parts = [p for p in content if p[\"type\"] == \"thinking\"]\n    text_parts = [p for p in content if p[\"type\"] == \"text\"]\n    assert len(thinking_parts) == 1\n    assert \"reason\" in thinking_parts[0][\"thinking\"]\n    assert len(text_parts) == 1\n    assert \"42\" in text_parts[0][\"text\"]\n\n\ndef test_parse_response_for_streaming_with_thinking(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"_parse_response_for_streaming preserves the single \\\\n separator after </think>.\n\n    The inherited _parse_response_for_streaming calls _postprocess_parsed_message\n    which strips separator newlines. 
For Nemotron-3, the separator is one \\\\n (not\n    two like Qwen3.5), so the text after thinking should NOT start with \\\\n (the\n    single separator should be stripped) and must not lose content by over-stripping.\n    \"\"\"\n    # Include <think>\\n prefix — in real streaming, _normalize_response_tokens\n    # prepends this before _parse_response_for_streaming is called.\n    response_text = \"<think>\\nI should reason carefully.\\n</think>\\nThe answer is 42.<|im_end|>\"\n    tokens = nemotron_tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = nemotron_renderer._parse_response_for_streaming(tokens)\n\n    assert success\n    content = message.get(\"content\")\n    assert isinstance(content, list)\n    thinking_parts = [p for p in content if p[\"type\"] == \"thinking\"]\n    text_parts = [p for p in content if p[\"type\"] == \"text\"]\n    assert len(thinking_parts) == 1\n    assert \"reason\" in thinking_parts[0][\"thinking\"]\n    assert len(text_parts) == 1\n    # The text should start with \"The answer\" — the \\n separator should be stripped,\n    # not left as-is (0 newlines stripped) or double-stripped.\n    assert text_parts[0][\"text\"].startswith(\"The answer\"), (\n        f\"Expected text to start with 'The answer' but got: {text_parts[0]['text']!r}\"\n    )\n\n\ndef test_parse_response_tool_call(nemotron_tokenizer, nemotron_renderer):\n    \"\"\"parse_response parses XML-format tool calls.\"\"\"\n    tool_call_text = (\n        \"\\n</think>\\n\"\n        \"<tool_call>\\n\"\n        \"<function=get_weather>\\n\"\n        \"<parameter=location>\\n\"\n        \"New York, NY\\n\"\n        \"</parameter>\\n\"\n        \"</function>\\n\"\n        \"</tool_call><|im_end|>\"\n    )\n    tokens = nemotron_tokenizer.encode(tool_call_text, add_special_tokens=False)\n    message, success = nemotron_renderer.parse_response(tokens)\n\n    assert success\n    tool_calls = message.get(\"tool_calls\", [])\n    assert len(tool_calls) == 1\n    assert tool_calls[0].function.name == \"get_weather\"\n    args = json.loads(tool_calls[0].function.arguments)\n    assert args[\"location\"] == \"New York, NY\"\n\n\n# =============================================================================\n# Renderer Identity Tests\n# =============================================================================\n\n\ndef test_renderer_types(nemotron_renderer, nemotron_renderer_disable_thinking):\n    assert isinstance(nemotron_renderer, Nemotron3Renderer)\n    assert isinstance(nemotron_renderer_disable_thinking, Nemotron3DisableThinkingRenderer)\n\n\ndef test_renderer_is_not_qwen35(nemotron_renderer):\n    from tinker_cookbook.renderers.qwen3_5 import Qwen3_5Renderer\n\n    assert type(nemotron_renderer) is not Qwen3_5Renderer\n"
  },
  {
    "path": "tinker_cookbook/renderers/parsing_test.py",
    "content": "\"\"\"Tests for shared parsing utilities and cross-renderer roundtrip behavior.\"\"\"\n\nimport pytest\n\nfrom tinker_cookbook.renderers import (\n    ContentPart,\n    DeepSeekV3ThinkingRenderer,\n    GptOssRenderer,\n    Message,\n    Qwen3Renderer,\n    RenderContext,\n    TextPart,\n    ThinkingPart,\n    format_content_as_string,\n    parse_content_blocks,\n)\nfrom tinker_cookbook.renderers.base import (\n    ToolCall,\n    UnparsedToolCall,\n    Utf8TokenDecoder,\n    _longest_matching_suffix_prefix,\n    ensure_list,\n)\nfrom tinker_cookbook.renderers.deepseek_v3 import DeepSeekV3DisableThinkingRenderer\nfrom tinker_cookbook.renderers.kimi_k2 import KimiK2Renderer\nfrom tinker_cookbook.renderers.kimi_k25 import KimiK25Renderer\nfrom tinker_cookbook.renderers.qwen3_5 import Qwen3_5DisableThinkingRenderer, Qwen3_5Renderer\nfrom tinker_cookbook.renderers.testing_utils import skip_if_deepseek_tokenizer_bug\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n# =============================================================================\n# parse_content_blocks Tests\n# =============================================================================\n\n\ndef test_parse_content_blocks_no_special_tags():\n    \"\"\"Test parse_content_blocks returns None when no special tags.\"\"\"\n    parts = parse_content_blocks(\"Just plain text\")\n    assert parts is None\n\n\ndef test_parse_content_blocks_single_think_block():\n    \"\"\"Test parse_content_blocks with a single think block.\"\"\"\n    result = parse_content_blocks(\"<think>reasoning</think>visible answer\")\n    assert result is not None\n    parts, tool_calls = result\n\n    assert len(parts) == 2\n    assert parts[0][\"type\"] == \"thinking\"\n    assert parts[0][\"thinking\"] == \"reasoning\"  # type: ignore[typeddict-item]\n    assert parts[1][\"type\"] == \"text\"\n    assert parts[1][\"text\"] == \"visible answer\"  # type: ignore[typeddict-item]\n    assert tool_calls == []\n\n\ndef test_parse_content_blocks_multiple_think_blocks():\n    \"\"\"Test parse_content_blocks with multiple interleaved think blocks.\"\"\"\n    content = \"<think>step 1</think>partial<think>step 2</think>final\"\n    result = parse_content_blocks(content)\n    assert result is not None\n    parts, tool_calls = result\n\n    assert len(parts) == 4\n    assert parts[0] == ThinkingPart(type=\"thinking\", thinking=\"step 1\")\n    assert parts[1] == TextPart(type=\"text\", text=\"partial\")\n    assert parts[2] == ThinkingPart(type=\"thinking\", thinking=\"step 2\")\n    assert parts[3] == TextPart(type=\"text\", text=\"final\")\n    assert tool_calls == []\n\n\ndef test_parse_content_blocks_empty_blocks_omitted():\n    \"\"\"Test parse_content_blocks omits empty think blocks.\"\"\"\n    result = parse_content_blocks(\"<think></think>visible\")\n    assert result is not None\n    parts, tool_calls = result\n\n    assert len(parts) == 1\n    assert parts[0][\"type\"] == \"text\"\n    assert parts[0][\"text\"] == \"visible\"  # type: ignore[typeddict-item]\n    assert tool_calls == []\n\n\ndef test_parse_content_blocks_whitespace_handling():\n    \"\"\"Test parse_content_blocks preserves whitespace for identity roundtrip.\"\"\"\n    result = parse_content_blocks(\"<think>  thinking  </think>  answer  \")\n    assert result is not None\n    parts, tool_calls = result\n\n    assert len(parts) == 2\n    # Whitespace is preserved exactly for identity roundtrip\n    assert parts[0][\"type\"] == \"thinking\" and parts[0][\"thinking\"] == \"  
thinking  \"  # type: ignore[typeddict-item]\n    assert parts[1][\"type\"] == \"text\" and parts[1][\"text\"] == \"  answer  \"  # type: ignore[typeddict-item]\n    assert tool_calls == []\n\n\ndef test_parse_content_blocks_tool_call_only():\n    \"\"\"Test parse_content_blocks parses tool calls into separate list.\"\"\"\n    content = '<tool_call>{\"name\": \"search\", \"arguments\": {\"query\": \"test\"}}</tool_call>'\n    result = parse_content_blocks(content)\n    assert result is not None\n    parts, tool_calls = result\n\n    assert len(parts) == 0\n    assert len(tool_calls) == 1\n    assert isinstance(tool_calls[0], ToolCall)\n    assert tool_calls[0].function.name == \"search\"\n    assert tool_calls[0].function.arguments == '{\"query\": \"test\"}'\n\n\ndef test_parse_content_blocks_interleaved():\n    \"\"\"Test parse_content_blocks handles interleaved think and tool_call blocks.\"\"\"\n    content = '<think>Let me search</think>Searching...<tool_call>{\"name\": \"search\", \"arguments\": {\"q\": \"test\"}}</tool_call>Done'\n    result = parse_content_blocks(content)\n    assert result is not None\n    parts, tool_calls = result\n\n    # Content parts: think, text before tool_call, text after tool_call\n    assert len(parts) == 3\n    assert parts[0] == ThinkingPart(type=\"thinking\", thinking=\"Let me search\")\n    assert parts[1] == TextPart(type=\"text\", text=\"Searching...\")\n    assert parts[2] == TextPart(type=\"text\", text=\"Done\")\n\n    # Tool call extracted separately\n    assert len(tool_calls) == 1\n    assert isinstance(tool_calls[0], ToolCall)\n    assert tool_calls[0].function.name == \"search\"\n\n\ndef test_parse_content_blocks_invalid_tool_call():\n    \"\"\"Test parse_content_blocks handles invalid tool call JSON as UnparsedToolCall.\"\"\"\n    content = \"<tool_call>not valid json</tool_call>text after\"\n    result = parse_content_blocks(content)\n    assert result is not None\n    parts, tool_calls = result\n\n    # Text after tool call is captured in content parts\n    assert len(parts) == 1\n    assert parts[0] == TextPart(type=\"text\", text=\"text after\")\n\n    # Invalid tool call is in tool_calls list as UnparsedToolCall\n    assert len(tool_calls) == 1\n    assert isinstance(tool_calls[0], UnparsedToolCall)\n    assert \"Invalid JSON\" in tool_calls[0].error\n\n\ndef test_format_content_as_string_roundtrip():\n    \"\"\"Formatted content should be parseable back to original.\"\"\"\n    content = [\n        ThinkingPart(type=\"thinking\", thinking=\"reasoning\"),\n        TextPart(type=\"text\", text=\"answer\"),\n    ]\n    # Use empty separator for true roundtrip (default separator adds newlines between parts)\n    formatted = format_content_as_string(content, separator=\"\")\n    result = parse_content_blocks(formatted)\n    assert result is not None\n    parts, tool_calls = result\n    assert parts == content\n    assert tool_calls == []\n\n\n# =============================================================================\n# _longest_matching_suffix_prefix Tests\n# =============================================================================\n\n\ndef test_longest_matching_suffix_prefix():\n    \"\"\"Test the suffix-prefix matching helper function.\"\"\"\n    # No match cases\n    assert _longest_matching_suffix_prefix(\"hello\", \"<think>\") == 0\n    assert _longest_matching_suffix_prefix(\"hello world\", \"<think>\") == 0\n    assert _longest_matching_suffix_prefix(\"\", \"<think>\") == 0\n\n    # Partial matches\n    assert 
_longest_matching_suffix_prefix(\"hello<\", \"<think>\") == 1\n    assert _longest_matching_suffix_prefix(\"hello<t\", \"<think>\") == 2\n    assert _longest_matching_suffix_prefix(\"hello<th\", \"<think>\") == 3\n    assert _longest_matching_suffix_prefix(\"hello<thi\", \"<think>\") == 4\n    assert _longest_matching_suffix_prefix(\"hello<thin\", \"<think>\") == 5\n    assert _longest_matching_suffix_prefix(\"hello<think\", \"<think>\") == 6\n\n    # Non-matching partial (doesn't match prefix)\n    assert _longest_matching_suffix_prefix(\"hello<thx\", \"<think>\") == 0\n    assert _longest_matching_suffix_prefix(\"hello<tx\", \"<think>\") == 0\n\n    # For </think>\n    assert _longest_matching_suffix_prefix(\"thinking</\", \"</think>\") == 2\n    assert _longest_matching_suffix_prefix(\"thinking</t\", \"</think>\") == 3\n    assert _longest_matching_suffix_prefix(\"thinking</think\", \"</think>\") == 7\n\n    # Edge: text shorter than tag\n    assert _longest_matching_suffix_prefix(\"<t\", \"<think>\") == 2\n    assert _longest_matching_suffix_prefix(\"<\", \"<think>\") == 1\n\n\n# =============================================================================\n# Utf8TokenDecoder Tests\n# =============================================================================\n\n\ndef test_utf8_decoder_non_monotonic_decodability():\n    \"\"\"Test that Utf8TokenDecoder handles non-monotonic decodability.\n\n    This test would FAIL with binary search but PASSES with backwards scan.\n\n    The scenario: tokens [A, B, C, D] where:\n    - decode([A]) fails (partial UTF-8)\n    - decode([A, B]) fails (still partial)\n    - decode([A, B, C]) succeeds (completes the character!)\n    - decode([A, B, C, D]) fails (D starts a new partial)\n\n    Binary search would:\n    - Try mid=2: decode([A,B]) fails → high=1\n    - Try mid=1: decode([A]) fails → high=0\n    - Return None (WRONG - missed that [:3] works!)\n\n    Backwards scan:\n    - Try removing 1 token: decode([A,B,C]) succeeds → return it ✓\n    \"\"\"\n\n    class MockTokenizer:\n        \"\"\"Mock tokenizer that simulates non-monotonic UTF-8 decodability.\"\"\"\n\n        def decode(self, tokens: list[int]) -> str:\n            # Simulate: tokens 1,2,3 together form valid UTF-8,\n            # but subsets [1], [1,2] are invalid, and [1,2,3,4] is invalid\n            # (token 4 starts a new incomplete sequence)\n            if tokens == [1, 2, 3]:\n                return \"✓\"  # Only this combination decodes\n            elif tokens == [1, 2, 3, 4] or 4 in tokens:\n                raise ValueError(\"Incomplete UTF-8: token 4 is partial\")\n            else:\n                raise ValueError(f\"Incomplete UTF-8: {tokens}\")\n\n    decoder = Utf8TokenDecoder(MockTokenizer())  # type: ignore[arg-type]\n\n    # Feed all 4 tokens at once\n    result = decoder.decode([1, 2, 3, 4])\n\n    # Should decode [1,2,3] and buffer [4]\n    assert result == \"✓\", f\"Expected '✓' but got {result!r}\"\n    assert decoder._pending_tokens == [4], f\"Expected [4] pending but got {decoder._pending_tokens}\"\n\n\ndef test_utf8_decoder_with_real_tokenizer_ascii():\n    \"\"\"Test Utf8TokenDecoder with real tokenizer on ASCII text.\n\n    Note: Many tokenizers (including tiktoken-based ones like Kimi) don't throw\n    exceptions for incomplete UTF-8 - they return replacement characters (â).\n    This means our exception-based buffering doesn't help for those tokenizers.\n\n    However, for ASCII text (single-byte UTF-8), there's no splitting issue,\n    so the decoder 
should work correctly.\n    \"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n\n    # ASCII-only text won't have UTF-8 splitting issues\n    test_str = \"Hello World! How are you today?\"\n    tokens = tokenizer.encode(test_str, add_special_tokens=False)\n\n    # Feed tokens one at a time and collect decoded text\n    decoder = Utf8TokenDecoder(tokenizer)\n    decoded_parts = []\n    for token in tokens:\n        result = decoder.decode([token])\n        if result is not None:\n            decoded_parts.append(result)\n\n    # Flush any remaining\n    remaining = decoder.flush()\n    if remaining:\n        decoded_parts.append(remaining)\n\n    # Concatenated result should match original\n    full_decoded = \"\".join(decoded_parts)\n    assert full_decoded == test_str, f\"Expected {test_str!r} but got {full_decoded!r}\"\n\n\ndef test_utf8_decoder_handles_replacement_chars():\n    \"\"\"Test that Utf8TokenDecoder handles tokenizers that return replacement chars.\n\n    Tiktoken-based tokenizers (like Kimi's) return U+FFFD (replacement character)\n    for incomplete UTF-8 instead of raising exceptions. The decoder detects these\n    and buffers tokens until the sequence completes.\n    \"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n\n    # The emoji 🎉 is encoded as multiple tokens\n    test_str = \"🎉\"\n    tokens = tokenizer.encode(test_str, add_special_tokens=False)\n\n    # Verify tokens individually decode to replacement/garbled chars (confirming tiktoken behavior)\n    for tok in tokens:\n        decoded = tokenizer.decode([tok])\n        assert decoded != test_str, (\n            f\"Expected garbled output for partial token {tok}, got {decoded!r}\"\n        )\n\n    # Now test that our decoder handles this correctly\n    decoder = Utf8TokenDecoder(tokenizer)\n    decoded_parts = []\n\n    for token in tokens:\n        result = decoder.decode([token])\n        if result is not None:\n            decoded_parts.append(result)\n\n    # Flush any remaining\n    remaining = decoder.flush()\n    if remaining:\n        decoded_parts.append(remaining)\n\n    # The concatenated result should be the original emoji (no replacement chars)\n    full_decoded = \"\".join(decoded_parts)\n    assert full_decoded == test_str, f\"Expected {test_str!r} but got {full_decoded!r}\"\n\n\ndef test_utf8_decoder_mixed_ascii_and_emoji():\n    \"\"\"Test streaming with mixed ASCII and multi-byte Unicode.\"\"\"\n    tokenizer = get_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n\n    # Mix of ASCII and emoji\n    test_str = \"Hello 🎉 World 🌍!\"\n    tokens = tokenizer.encode(test_str, add_special_tokens=False)\n\n    decoder = Utf8TokenDecoder(tokenizer)\n    decoded_parts = []\n\n    for token in tokens:\n        result = decoder.decode([token])\n        if result is not None:\n            decoded_parts.append(result)\n\n    remaining = decoder.flush()\n    if remaining:\n        decoded_parts.append(remaining)\n\n    full_decoded = \"\".join(decoded_parts)\n    assert full_decoded == test_str, f\"Expected {test_str!r} but got {full_decoded!r}\"\n    assert \"â\" not in full_decoded, \"Should not contain replacement characters\"\n\n\n# =============================================================================\n# Cross-Renderer Parse Correspondence Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_cls,renderer_kwargs\",\n    [\n        (\"deepseek-ai/DeepSeek-V3.1\", 
DeepSeekV3ThinkingRenderer, {}),\n        (\"deepseek-ai/DeepSeek-V3.1\", DeepSeekV3DisableThinkingRenderer, {}),\n        (\n            \"openai/gpt-oss-20b\",\n            GptOssRenderer,\n            {\"use_system_prompt\": True, \"reasoning_effort\": \"medium\"},\n        ),\n        (\"Qwen/Qwen3-30B-A3B\", Qwen3Renderer, {}),\n        (\"Qwen/Qwen3.5-35B-A3B\", Qwen3_5Renderer, {}),\n        (\"Qwen/Qwen3.5-35B-A3B\", Qwen3_5DisableThinkingRenderer, {}),\n        (\"moonshotai/Kimi-K2-Thinking\", KimiK2Renderer, {}),\n        (\"moonshotai/Kimi-K2.5\", KimiK25Renderer, {}),\n    ],\n)\ndef test_thinking_generation_parse_correspondence(model_name, renderer_cls, renderer_kwargs):\n    \"\"\"Test that parse_response handles sampled output after thinking prefill.\n\n    Pattern for thinking model tests:\n    1. Build generation prompt (may include thinking prefill)\n    2. Render expected message to get full response tokens\n    3. Strip prefill to simulate what sampling returns\n    4. Parse continuation → should recover the expected message\n    5. Roundtrip: prompt + continuation = full supervised example\n    \"\"\"\n    skip_if_deepseek_tokenizer_bug(model_name)\n    tokenizer = get_tokenizer(model_name)\n    renderer = renderer_cls(tokenizer, **renderer_kwargs)\n\n    # User message\n    user_message: Message = {\"role\": \"user\", \"content\": \"What is 2+2?\"}\n\n    # Expected parsed message (what we want parse_response to produce)\n    thinking: list[ContentPart] = []\n    if \"DisableThinking\" not in renderer_cls.__name__:\n        thinking = [ThinkingPart(type=\"thinking\", thinking=\"Let me work through this.\")]\n    expected_content = thinking + [TextPart(type=\"text\", text=\"The answer is 42.\")]\n    expected_message: Message = {\"role\": \"assistant\", \"content\": expected_content}\n\n    # Render expected message to get full response tokens\n    rendered = renderer.render_message(\n        expected_message, RenderContext(idx=1, is_last=True, prev_message=user_message)\n    )\n    full_response_tokens = [t for chunk in rendered.output for t in chunk.tokens]\n\n    # Build prompt (may include prefill like <think>)\n    prompt = renderer.build_generation_prompt([user_message])\n    prompt_tokens = prompt.to_ints()\n\n    # Find prefill by matching end of prompt with start of rendered response\n    # This is renderer-agnostic: whatever prefill the renderer adds will be found\n    prefill_len = 0\n    for i in range(1, min(len(prompt_tokens), len(full_response_tokens)) + 1):\n        if prompt_tokens[-i:] == full_response_tokens[:i]:\n            prefill_len = i\n\n    # Simulate sampling: strip prefill\n    continuation_tokens = full_response_tokens[prefill_len:]\n\n    # Parse the continuation\n    parsed_message, _ = renderer.parse_response(continuation_tokens)\n\n    # Should recover the expected message\n    assert ensure_list(parsed_message[\"content\"]) == ensure_list(expected_message[\"content\"]), (\n        f\"Roundtrip failed: parsed_message != expected_message for {model_name} {renderer_cls.__name__}\"\n    )\n\n    # Roundtrip: full conversation should match prompt + continuation\n    full_convo = [user_message, parsed_message]\n    supervised, _ = renderer.build_supervised_example(full_convo)\n    assert supervised.to_ints() == prompt_tokens + continuation_tokens\n"
  },
  {
    "path": "tinker_cookbook/renderers/qwen3.py",
    "content": "\"\"\"\nQwen3 family renderers - text and vision-language models.\n\nIncludes:\n- Qwen3Renderer: Base Qwen3 with thinking enabled\n- Qwen3DisableThinkingRenderer: Qwen3 with thinking disabled\n- Qwen3InstructRenderer: Qwen3 instruct 2507 models (no <think> tag)\n- Qwen3VLRenderer: Vision-language Qwen3 with thinking\n- Qwen3VLInstructRenderer: Vision-language instruct models\n\"\"\"\n\nimport json\nfrom typing import cast\n\nimport tinker\n\nfrom tinker_cookbook.image_processing_utils import ImageProcessor\nfrom tinker_cookbook.renderers.base import (\n    ImagePart,\n    ImageProcessorProtocol,\n    Message,\n    RenderContext,\n    RenderedMessage,\n    Renderer,\n    TextPart,\n    ToolCall,\n    ToolSpec,\n    UnparsedToolCall,\n    _tool_call_payload,\n    image_to_chunk,\n    parse_content_blocks,\n    parse_response_for_stop_token,\n    remove_thinking,\n)\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n\ndef _merge_consecutive_text_parts(\n    chunks: list[ImagePart | TextPart],\n) -> list[ImagePart | TextPart]:\n    \"\"\"Merge consecutive TextParts into single parts.\n\n    This ensures text is tokenized as a single string, matching HuggingFace's\n    apply_chat_template behavior which tokenizes the full rendered string at once.\n    Without merging, tokenization boundaries between chunks can produce different\n    token sequences (though they decode to identical strings).\n    \"\"\"\n    if not chunks:\n        return chunks\n\n    merged: list[ImagePart | TextPart] = [chunks[0]]\n    for chunk in chunks[1:]:\n        if chunk[\"type\"] == \"text\" and merged[-1][\"type\"] == \"text\":\n            merged[-1] = TextPart(type=\"text\", text=merged[-1][\"text\"] + chunk[\"text\"])\n        else:\n            merged.append(chunk)\n    return merged\n\n\nclass Qwen3Renderer(Renderer):\n    \"\"\"\n    Renderer for Qwen3 models with thinking enabled.\n\n    This renderer is designed to match HuggingFace's Qwen3 chat template behavior\n    (with enable_thinking=True, which is the default). This ensures compatibility\n    with the OpenAI-compatible /chat/completions endpoint, which uses HF templates.\n\n    Reference: https://huggingface.co/Qwen/Qwen3-8B/blob/main/tokenizer_config.json\n\n    Format:\n        <|im_start|>system\n        You are Qwen, created by Alibaba Cloud.<|im_end|>\n        <|im_start|>user\n        What can you help me with?<|im_end|>\n        <|im_start|>assistant\n        <think>\n        [reasoning content]\n        </think>\n        I can help you with...<|im_end|>\n\n    The default strip_thinking_from_history=True matches HF behavior where thinking\n    blocks are stripped from historical assistant messages in multi-turn conversations.\n    Use strip_thinking_from_history=False for multi-turn RL to get the extension property.\n    \"\"\"\n\n    supports_streaming = True\n\n    def __init__(self, tokenizer: Tokenizer, strip_thinking_from_history: bool = True):\n        \"\"\"\n        Args:\n            tokenizer: The tokenizer to use for encoding.\n            strip_thinking_from_history: When True (default), strips <think>...</think> blocks\n                from assistant messages in multi-turn history. This matches HuggingFace's\n                Qwen3 chat template behavior. 
Set to False to preserve thinking in history\n                (useful for multi-turn RL where you need the extension property).\n\n        Note: When strip_thinking_from_history=True, this renderer produces identical\n        tokens to HuggingFace's apply_chat_template with enable_thinking=True.\n\n        See /rl/sequence-extension in the docs for details on how strip_thinking_from_history\n        affects multi-turn RL compute efficiency.\n        \"\"\"\n        super().__init__(tokenizer)\n        self.strip_thinking_from_history = strip_thinking_from_history\n\n    @property\n    def has_extension_property(self) -> bool:\n        \"\"\"Extension property depends on strip_thinking_from_history setting.\n\n        When strip_thinking_from_history=False, thinking blocks are preserved in\n        history, so each successive observation is a prefix extension of the previous.\n\n        When strip_thinking_from_history=True (default), thinking blocks are stripped\n        from historical messages, breaking the extension property.\n        \"\"\"\n        return not self.strip_thinking_from_history\n\n    def _get_qwen_role_for_message(self, message: Message) -> str:\n        \"\"\"Get the role to use for rendering a message in Qwen format.\n\n        Per HuggingFace Qwen3 chat template, tool messages are rendered with role \"user\".\n        \"\"\"\n        role = message[\"role\"]\n        if role == \"tool\":\n            return \"user\"\n        return role\n\n    def _wrap_qwen_tool_response(self, content: str) -> str:\n        \"\"\"Wrap tool response content in Qwen's <tool_response> tags.\"\"\"\n        return f\"<tool_response>\\n{content}\\n</tool_response>\"\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        maybe_newline = \"\\n\" if ctx.idx > 0 else \"\"\n\n        role = self._get_qwen_role_for_message(message)\n        header_str = f\"{maybe_newline}<|im_start|>{role}\\n\"\n\n        content = message[\"content\"]\n\n        if isinstance(content, list):\n            # Structured content - handle with list operations\n            parts = content\n            if (\n                self.strip_thinking_from_history\n                and message[\"role\"] == \"assistant\"\n                and not ctx.is_last\n            ):\n                # Remove thinking parts for historical messages\n                parts = remove_thinking(parts)\n            # Render parts in order, preserving interleaved thinking/text structure.\n            # No separator needed - whitespace is preserved in TextPart for roundtrip identity.\n            rendered_parts = []\n            for p in parts:\n                if p[\"type\"] == \"thinking\":\n                    rendered_parts.append(f\"<think>{p['thinking']}</think>\")\n                elif p[\"type\"] == \"text\":\n                    rendered_parts.append(p[\"text\"])\n            output_content = \"\".join(rendered_parts)\n        else:\n            # String content - pass through as-is.\n            # Note: strip_thinking_from_history only works with list-based content.\n            # For stripping to work on historical messages, use structured content\n            # with ThinkingPart separated from text (as returned by parse_response).\n            output_content = content\n\n        # Handle tool response wrapping\n        if message[\"role\"] == \"tool\":\n            output_content = self._wrap_qwen_tool_response(output_content)\n\n        # Handle tool_calls field\n        if \"tool_calls\" in 
message:\n            # Add leading newline to match HF template behavior\n            output_content += \"\\n\" + \"\\n\".join(\n                [\n                    f\"<tool_call>\\n{json.dumps(_tool_call_payload(tool_call))}\\n</tool_call>\"\n                    for tool_call in message[\"tool_calls\"]\n                ]\n            )\n        output_content += \"<|im_end|>\"\n        header = tinker.types.EncodedTextChunk(\n            tokens=self.tokenizer.encode(header_str, add_special_tokens=False)\n        )\n        output: list[tinker.ModelInputChunk] = [\n            tinker.types.EncodedTextChunk(\n                tokens=self.tokenizer.encode(output_content, add_special_tokens=False)\n            )\n        ]\n        return RenderedMessage(header=header, output=output)\n\n    @property\n    def _end_message_token(self) -> int:\n        tokens = self.tokenizer.encode(\"<|im_end|>\", add_special_tokens=False)\n        assert len(tokens) == 1, f\"Expected single token for <|im_end|>, got {len(tokens)}\"\n        return tokens[0]\n\n    def get_stop_sequences(self) -> list[int]:\n        return [self._end_message_token]\n\n    def parse_response(self, response: list[int]) -> tuple[Message, bool]:\n        response = self._normalize_response_tokens(response)\n        assistant_message, parse_success = parse_response_for_stop_token(\n            response, self.tokenizer, self._end_message_token\n        )\n        if not parse_success:\n            return assistant_message, False\n\n        # Parse <think>...</think> and <tool_call>...</tool_call> blocks together\n        # to preserve ordering. Tool calls use Qwen's format:\n        # - https://qwen.readthedocs.io/en/latest/getting_started/concepts.html#tool-calling\n        # - https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py#L279-L282\n        assert isinstance(assistant_message[\"content\"], str)\n        content = assistant_message[\"content\"]\n\n        # Parse all blocks in one pass, preserving order\n        result = parse_content_blocks(content)\n\n        if result is not None:\n            parts, tool_results = result\n            assistant_message[\"content\"] = parts\n\n            tool_calls = [t for t in tool_results if isinstance(t, ToolCall)]\n            unparsed = [t for t in tool_results if isinstance(t, UnparsedToolCall)]\n            if tool_calls:\n                assistant_message[\"tool_calls\"] = tool_calls\n            if unparsed:\n                assistant_message[\"unparsed_tool_calls\"] = unparsed\n        else:\n            # No special blocks found - keep as string for backward compatibility\n            assistant_message[\"content\"] = content\n\n        return assistant_message, True\n\n    def _parse_response_for_streaming(self, response: list[int]) -> tuple[Message, bool]:\n        \"\"\"Parse response for streaming, always applying full content parsing.\n\n        Unlike parse_response which short-circuits on missing stop token,\n        this always parses think blocks and tool calls from the content.\n        This ensures the final Message emitted by streaming is complete\n        even for truncated responses.\n\n        Note: _normalize_response_tokens is NOT called here because\n        parse_response_streaming already normalizes before feeding tokens\n        to the parser.\n        \"\"\"\n        assistant_message, parse_success = parse_response_for_stop_token(\n            response, self.tokenizer, self._end_message_token\n        )\n\n 
       assert isinstance(assistant_message[\"content\"], str)\n        content = assistant_message[\"content\"]\n\n        result = parse_content_blocks(content)\n\n        if result is not None:\n            parts, tool_results = result\n            assistant_message[\"content\"] = parts\n\n            tool_calls = [t for t in tool_results if isinstance(t, ToolCall)]\n            unparsed = [t for t in tool_results if isinstance(t, UnparsedToolCall)]\n            if tool_calls:\n                assistant_message[\"tool_calls\"] = tool_calls\n            if unparsed:\n                assistant_message[\"unparsed_tool_calls\"] = unparsed\n        else:\n            assistant_message[\"content\"] = content\n\n        return assistant_message, parse_success\n\n    def to_openai_message(self, message: Message) -> dict:\n        \"\"\"Convert a Message to OpenAI API format with reasoning_content for thinking.\n\n        Qwen3's HF template accepts either:\n        - message['reasoning_content'] as a separate field\n        - <think>...</think> embedded in content\n\n        We use reasoning_content for cleaner separation.\n        \"\"\"\n        result: dict = {\"role\": message[\"role\"]}\n\n        content = message[\"content\"]\n        if isinstance(content, str):\n            result[\"content\"] = content\n        else:\n            # Extract thinking into reasoning_content, keep text in content\n            thinking_parts = []\n            text_parts = []\n            for p in content:\n                if p[\"type\"] == \"thinking\":\n                    thinking_parts.append(p[\"thinking\"])\n                elif p[\"type\"] == \"text\":\n                    text_parts.append(p[\"text\"])\n                # Skip tool_call/unparsed_tool_call - handled via tool_calls field\n\n            result[\"content\"] = \"\".join(text_parts)\n            if thinking_parts:\n                result[\"reasoning_content\"] = \"\".join(thinking_parts)\n\n        # Handle tool_calls\n        if \"tool_calls\" in message and message[\"tool_calls\"]:  # noqa: RUF019\n            result[\"tool_calls\"] = [\n                {\n                    \"type\": \"function\",\n                    \"id\": tc.id,\n                    \"function\": {\n                        \"name\": tc.function.name,\n                        \"arguments\": self._to_openai_tool_arguments(tc.function.arguments),\n                    },\n                }\n                for tc in message[\"tool_calls\"]\n            ]\n\n        # Handle tool response fields\n        if message[\"role\"] == \"tool\":\n            if \"tool_call_id\" in message:\n                result[\"tool_call_id\"] = message[\"tool_call_id\"]\n            if \"name\" in message:\n                result[\"name\"] = message[\"name\"]\n\n        return result\n\n    def _to_openai_tool_arguments(self, arguments: str) -> str | dict:\n        \"\"\"Convert tool arguments for OpenAI-compatible message payloads.\n\n        Qwen3 templates accept JSON-string arguments directly; subclasses can\n        override to return dicts for templates that iterate over arguments.\n        \"\"\"\n        return arguments\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> list[Message]:\n        \"\"\"Create system message with Qwen3 tool specifications.\n\n        Qwen3 uses XML `<tools>` tags containing JSON tool definitions in OpenAI format,\n        appended to the system message content.\n\n        
References:\n        - https://qwen.readthedocs.io/en/latest/getting_started/concepts.html#tool-calling\n        - https://huggingface.co/Qwen/Qwen3-8B/blob/main/tokenizer_config.json\n        \"\"\"\n        tools_text = \"\"\n        if tools:\n            # Each tool is wrapped in {\"type\": \"function\", \"function\": {...}} per OpenAI format\n            # Use separators=(\", \", \": \") to match HF's tojson filter output\n            tool_lines = \"\\n\".join(\n                json.dumps(\n                    {\"type\": \"function\", \"function\": tool},\n                    separators=(\", \", \": \"),\n                )\n                for tool in tools\n            )\n            tools_text = f\"\"\"# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tool_lines}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{{\"name\": <function-name>, \"arguments\": <args-json-object>}}\n</tool_call>\"\"\"\n\n        # Add separator between system prompt and tools if system prompt exists\n        if system_prompt:\n            content = system_prompt + \"\\n\\n\" + tools_text\n        else:\n            content = tools_text\n\n        return [Message(role=\"system\", content=content)]\n\n\nclass Qwen3DisableThinkingRenderer(Qwen3Renderer):\n    \"\"\"\n    Renderer for Qwen3 hybrid models with thinking disabled.\n\n    This renderer matches HuggingFace's Qwen3 chat template behavior with\n    enable_thinking=False (or thinking=False for apply_chat_template). It adds\n    empty <think>\\\\n\\\\n</think>\\\\n\\\\n blocks to assistant messages, signaling to\n    the model that it should respond directly without extended reasoning.\n\n    Use this renderer when you want to train or sample from Qwen3 models in\n    \"non-thinking\" mode while maintaining compatibility with the OpenAI endpoint.\n    \"\"\"\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        # Get the base rendered message\n        rendered = super().render_message(message, ctx)\n\n        # Add empty thinking block to header for last assistant message\n        # This goes in header (weight=0) so observation matches generation prompt.\n        if message[\"role\"] == \"assistant\" and ctx.is_last:\n            content = message.get(\"content\", \"\")\n            if isinstance(content, str):\n                has_think = \"<think>\" in content\n            else:\n                has_think = any(p[\"type\"] == \"thinking\" for p in content)\n\n            if not has_think:\n                empty_think_tokens = self.tokenizer.encode(\n                    \"<think>\\n\\n</think>\\n\\n\", add_special_tokens=False\n                )\n                old_header_tokens = list(rendered.header.tokens) if rendered.header else []\n                new_header = tinker.EncodedTextChunk(tokens=old_header_tokens + empty_think_tokens)\n                rendered = RenderedMessage(\n                    header=new_header, output=rendered.output, stop_overlap=rendered.stop_overlap\n                )\n\n        return rendered\n\n\nclass Qwen3InstructRenderer(Qwen3Renderer):\n    \"\"\"\n    Renderer for Qwen3 instruct 2507 models. Unlike the earlier Qwen3 models, these models do not\n    use the <think> tag at all.\n\n    Inherits from Qwen3Renderer. 
ThinkingPart in content is still handled (rendered as\n    <think>...</think>) in case the conversation includes thinking.\n    \"\"\"\n\n    @property\n    def has_extension_property(self) -> bool:\n        \"\"\"Qwen3 Instruct always satisfies extension - no thinking to strip from history.\"\"\"\n        # NOTE: If callers include ThinkingPart in history, Qwen3Renderer may still strip it\n        # when strip_thinking_from_history=True, so extension can break.\n        # This is a rare case that'll only occur if we prompt the instruct model\n        # with a conversation from a different model.\n        return True\n\n\nclass Qwen3VLRenderer(Qwen3Renderer):\n    \"\"\"\n    Vision-language renderer for Qwen3-VL models with thinking support.\n\n    Format like this:\n        <|im_start|>system\n        You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n        <|im_start|>user\n        What can you help me with?<|im_end|>\n        <|im_start|>assistant\n        <think>\n\n        </think>\n        I can help you with...<|im_end|>\n\n    The default strip_thinking_from_history=True matches the non-VL Qwen3Renderer behavior.\n    \"\"\"\n\n    image_processor: ImageProcessor | None\n\n    def __init__(\n        self,\n        tokenizer: Tokenizer,\n        image_processor: ImageProcessor | None = None,\n        strip_thinking_from_history: bool = True,\n        merge_text_chunks: bool = True,\n    ):\n        self.tokenizer = tokenizer\n        self.image_processor = image_processor\n        self.strip_thinking_from_history = strip_thinking_from_history\n        self.merge_text_chunks = merge_text_chunks\n\n    def _format_thinking_text(self, thinking: str) -> str:\n        \"\"\"Format a ThinkingPart payload for rendering.\"\"\"\n        return f\"<think>{thinking}</think>\"\n\n    def _assistant_header_suffix(self, message: Message, ctx: RenderContext) -> str:\n        \"\"\"Additional assistant header text injected before content.\"\"\"\n        return \"\"\n\n    def _preprocess_message_parts(\n        self, message: Message, *, strip_thinking: bool = False\n    ) -> list[ImagePart | TextPart]:\n        \"\"\"Convert message content to list form for VL rendering.\n\n        Converts ThinkingPart to <think>...</think> text (or strips if strip_thinking=True).\n        Wraps images with vision tokens. 
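For example, each ImagePart is emitted between text parts containing\n        \"<|vision_start|>\" and \"<|vision_end|>\". 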
Tool calls are in message's tool_calls field.\n        \"\"\"\n        content = message[\"content\"]\n        if isinstance(content, str):\n            base_parts: list[ImagePart | TextPart] = [TextPart(type=\"text\", text=content)]\n        else:\n            # Convert structured content to ImagePart/TextPart list\n            base_parts: list[ImagePart | TextPart] = []\n            for p in content:\n                if p[\"type\"] == \"text\":\n                    base_parts.append(cast(TextPart, p))\n                elif p[\"type\"] == \"image\":\n                    base_parts.append(cast(ImagePart, p))\n                elif p[\"type\"] == \"thinking\" and not strip_thinking:\n                    # Render thinking as <think>...</think> text\n                    base_parts.append(\n                        TextPart(type=\"text\", text=self._format_thinking_text(p[\"thinking\"]))\n                    )\n                    # else: strip thinking by not appending\n\n        # Wrap images with vision tokens\n        chunks: list[ImagePart | TextPart] = []\n        for content_chunk in base_parts:\n            if content_chunk[\"type\"] == \"image\":\n                chunks.append(TextPart(type=\"text\", text=\"<|vision_start|>\"))\n\n            chunks.append(content_chunk)\n\n            if content_chunk[\"type\"] == \"image\":\n                chunks.append(TextPart(type=\"text\", text=\"<|vision_end|>\"))\n\n        return chunks\n\n    def _wrap_qwen_tool_response_chunks(\n        self, chunks: list[ImagePart | TextPart]\n    ) -> list[ImagePart | TextPart]:\n        \"\"\"Wrap content chunks in Qwen's <tool_response> tags for multimodal messages.\"\"\"\n        return (\n            [TextPart(type=\"text\", text=\"<tool_response>\\n\")]\n            + chunks\n            + [TextPart(type=\"text\", text=\"\\n</tool_response>\")]\n        )\n\n    def _format_tool_calls_chunks(self, message: Message) -> list[ImagePart | TextPart]:\n        \"\"\"Format tool_calls as output chunks. 
Override in subclasses for different formats.\"\"\"\n        # Add leading newline to match HF template behavior\n        assert \"tool_calls\" in message, \"tool_calls are required to format tool calls\"\n        return [\n            TextPart(\n                type=\"text\",\n                text=\"\\n\"\n                + \"\\n\".join(\n                    [\n                        f\"<tool_call>\\n{json.dumps(_tool_call_payload(tool_call))}\\n</tool_call>\"\n                        for tool_call in message[\"tool_calls\"]\n                    ]\n                ),\n            )\n        ]\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        maybe_newline = \"\\n\" if ctx.idx > 0 else \"\"\n\n        role = self._get_qwen_role_for_message(message)\n        header_str = f\"{maybe_newline}<|im_start|>{role}\\n\"\n        if message[\"role\"] == \"assistant\":\n            header_str += self._assistant_header_suffix(message, ctx)\n\n        # Strip thinking from history for non-last assistant messages (matching non-VL behavior)\n        strip_thinking = (\n            self.strip_thinking_from_history and message[\"role\"] == \"assistant\" and not ctx.is_last\n        )\n        output_chunks = self._preprocess_message_parts(message, strip_thinking=strip_thinking)\n\n        # Handle tool response wrapping\n        if message[\"role\"] == \"tool\":\n            output_chunks = self._wrap_qwen_tool_response_chunks(output_chunks)\n\n        if \"tool_calls\" in message:\n            output_chunks += self._format_tool_calls_chunks(message)\n        output_chunks += [TextPart(type=\"text\", text=\"<|im_end|>\")]\n\n        if self.merge_text_chunks:\n            output_chunks = _merge_consecutive_text_parts(output_chunks)\n\n        output_chunks_encoded: list[tinker.ModelInputChunk] = []\n        for x in output_chunks:\n            if x[\"type\"] == \"image\":\n                assert self.image_processor is not None, (\n                    \"image_processor is required to render image content\"\n                )\n                output_chunks_encoded.append(\n                    image_to_chunk(\n                        image_or_str=x[\"image\"],\n                        image_processor=cast(ImageProcessorProtocol, self.image_processor),\n                    )\n                )\n            else:\n                output_chunks_encoded.append(\n                    tinker.EncodedTextChunk(\n                        tokens=self.tokenizer.encode(x[\"text\"], add_special_tokens=False)\n                    )\n                )\n\n        header = tinker.types.EncodedTextChunk(\n            tokens=self.tokenizer.encode(header_str, add_special_tokens=False)\n        )\n        return RenderedMessage(header=header, output=output_chunks_encoded)\n\n\nclass Qwen3VLInstructRenderer(Qwen3VLRenderer):\n    \"\"\"\n    Renderer for Qwen3-VL Instruct models.\n\n    Unlike the Qwen3-VL Thinking models, The Qwen3-VL Instruct models do not use the <think> tag.\n    \"\"\"\n\n    pass\n"
  },
  {
    "path": "tinker_cookbook/renderers/qwen3_5.py",
    "content": "\"\"\"\nQwen3.5 family renderer.\n\nQwen3.5 models are VL models with the same basic\nchat format as Qwen3-VL (im_start/im_end, thinking, vision tokens) but with a\ndifferent tool calling format.\n\nTool calling differences from Qwen3:\n- Qwen3: JSON format  {\"name\": ..., \"arguments\": ...}\n- Qwen3.5: XML format  <function=name><parameter=param>value</parameter></function>\n\nUnlike Qwen3, the Qwen3.5 HF template:\n- Always adds <think>...</think> blocks to assistant messages after the last user\n  message (empty if no reasoning content).\n- Always adds <think>\\\\n to the generation prompt.\n\nReference: https://huggingface.co/Qwen/Qwen3.5-4B/blob/main/tokenizer_config.json\n\"\"\"\n\nimport json\nimport re\n\nfrom tinker_cookbook.renderers.base import (\n    ImagePart,\n    Message,\n    RenderContext,\n    Role,\n    TextPart,\n    ToolCall,\n    ToolSpec,\n    UnparsedToolCall,\n)\nfrom tinker_cookbook.renderers.qwen3 import Qwen3VLRenderer\n\n_FUNCTION_BLOCK_RE = re.compile(\n    r\"^\\s*<tool_call>\\s*<function=(?P<name>[^>\\n]+)>\\s*(?P<body>.*?)\\s*</function>\\s*</tool_call>\\s*$\",\n    re.DOTALL,\n)\n_PARAM_BLOCK_RE = re.compile(\n    r\"<parameter=(?P<name>[^>\\n]+)>\\s*(?P<value>.*?)\\s*</parameter>\",\n    re.DOTALL,\n)\n\n\nclass Qwen3_5Renderer(Qwen3VLRenderer):\n    \"\"\"\n    Renderer for Qwen3.5 models.\n\n    Subclasses Qwen3VLRenderer since Qwen3.5 models are VL models sharing the same\n    basic chat format. Overrides tool calling to use Qwen3.5's XML parameter format.\n\n    The Qwen3.5 HF template adds empty <think> blocks to assistant messages after\n    the last user message. This is handled via ctx.last_user_index, which is\n    populated by the base build_generation_prompt/build_supervised_example.\n    \"\"\"\n\n    def _get_generation_suffix(self, role: Role, ctx: RenderContext) -> list[int]:\n        \"\"\"Override to produce the full generation suffix directly.\n\n        Builds the header tokens manually and appends <think>\\\\n. 
This matches\n        the Qwen3.5 template's add_generation_prompt behavior for thinking mode.\n        \"\"\"\n        maybe_newline = \"\\n\" if ctx.idx > 0 else \"\"\n        header_str = f\"{maybe_newline}<|im_start|>{role}\\n<think>\\n\"\n        return self.tokenizer.encode(header_str, add_special_tokens=False)\n\n    def _assistant_header_suffix(self, message: Message, ctx: RenderContext) -> str:\n        \"\"\"Insert empty think block for assistant messages after the last user query.\"\"\"\n        if ctx.idx <= ctx.last_user_index:\n            return \"\"\n\n        content = message.get(\"content\", \"\")\n        has_think = False\n        if isinstance(content, list):\n            has_think = any(p[\"type\"] == \"thinking\" for p in content)\n        elif isinstance(content, str):\n            has_think = \"<think>\" in content\n\n        return \"\" if has_think else \"<think>\\n\\n</think>\\n\\n\"\n\n    def _format_thinking_text(self, thinking: str) -> str:\n        \"\"\"Qwen3.5 uses newline-padded think blocks.\"\"\"\n        return f\"<think>\\n{thinking}\\n</think>\\n\\n\"\n\n    def _to_openai_tool_arguments(self, arguments: str) -> str | dict:\n        \"\"\"Qwen3.5 chat template expects arguments as a mapping for |items.\"\"\"\n        return json.loads(arguments)\n\n    def _parse_qwen3_5_tool_call_xml(self, raw_text: str) -> ToolCall | UnparsedToolCall:\n        \"\"\"Parse Qwen3.5 XML-style tool calls from a raw <tool_call> block.\"\"\"\n        match = _FUNCTION_BLOCK_RE.match(raw_text)\n        if not match:\n            return UnparsedToolCall(raw_text=raw_text, error=\"Malformed Qwen3.5 tool call XML\")\n\n        function_name = match.group(\"name\").strip()\n        body = match.group(\"body\")\n        if not function_name:\n            return UnparsedToolCall(raw_text=raw_text, error=\"Missing function name\")\n\n        arguments: dict[str, object] = {}\n        pos = 0\n        for param in _PARAM_BLOCK_RE.finditer(body):\n            if body[pos : param.start()].strip():\n                return UnparsedToolCall(\n                    raw_text=raw_text,\n                    error=\"Unexpected non-parameter content inside <function> block\",\n                )\n\n            param_name = param.group(\"name\").strip()\n            param_value_text = param.group(\"value\").strip(\"\\n\")\n            if not param_name:\n                return UnparsedToolCall(raw_text=raw_text, error=\"Empty parameter name\")\n\n            try:\n                param_value: object = json.loads(param_value_text)\n            except json.JSONDecodeError:\n                param_value = param_value_text\n\n            arguments[param_name] = param_value\n            pos = param.end()\n\n        if body[pos:].strip():\n            return UnparsedToolCall(\n                raw_text=raw_text,\n                error=\"Unexpected trailing content inside <function> block\",\n            )\n\n        return ToolCall(\n            function=ToolCall.FunctionBody(\n                name=function_name,\n                arguments=json.dumps(arguments),\n            )\n        )\n\n    def _normalize_response_tokens(self, response: list[int]) -> list[int]:\n        \"\"\"Restore the prefilled <think>\\\\n before parsing sampled tokens.\n\n        Qwen3.5's generation suffix includes <think>\\\\n, so sampled tokens start\n        after that prefix. 
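For example, a sampled response decoding to \"reasoning\\\\n</think>\\\\n\\\\nanswer\" is\n        treated as \"<think>\\\\nreasoning\\\\n</think>\\\\n\\\\nanswer\" before parsing. 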
If the response contains </think> but doesn't start\n        with <think>\\\\n, we prepend it so the parser sees a complete think block.\n        \"\"\"\n        think_prefix_tokens = self.tokenizer.encode(\"<think>\\n\", add_special_tokens=False)\n        think_suffix_token = self.tokenizer.encode(\"</think>\", add_special_tokens=False)\n        assert len(think_suffix_token) == 1\n\n        starts_with_think = (\n            len(response) >= len(think_prefix_tokens)\n            and response[: len(think_prefix_tokens)] == think_prefix_tokens\n        )\n        if not starts_with_think and think_suffix_token[0] in response:\n            return think_prefix_tokens + response\n        return response\n\n    def _postprocess_parsed_message(self, message: Message) -> None:\n        \"\"\"Apply Qwen3.5-specific post-processing to a parsed message in-place.\n\n        1. Strips whitespace from thinking content (matches HF template |trim).\n        2. Removes the two separator newlines between </think> and text.\n        3. Converts Qwen3.5 XML tool calls from the parent's unparsed_tool_calls.\n        \"\"\"\n        content = message.get(\"content\")\n        if isinstance(content, list):\n            first_text_after_thinking: TextPart | None = None\n            seen_thinking = False\n            for p in content:\n                if p[\"type\"] == \"thinking\":\n                    p[\"thinking\"] = p[\"thinking\"].strip()\n                    seen_thinking = True\n                elif seen_thinking and p[\"type\"] == \"text\":\n                    first_text_after_thinking = p\n                    break\n\n            # Template inserts exactly two separator newlines between </think> and text.\n            if first_text_after_thinking is not None and first_text_after_thinking[\n                \"text\"\n            ].startswith(\"\\n\\n\"):\n                first_text_after_thinking[\"text\"] = first_text_after_thinking[\"text\"][2:]\n\n        # Qwen3 parent parser assumes JSON inside <tool_call>; convert XML blocks here.\n        converted_xml_calls: list[ToolCall] = []\n        remaining_unparsed: list[UnparsedToolCall] = []\n        for unparsed in message.get(\"unparsed_tool_calls\", []):\n            if \"<function=\" not in unparsed.raw_text:\n                remaining_unparsed.append(unparsed)\n                continue\n            parsed = self._parse_qwen3_5_tool_call_xml(unparsed.raw_text)\n            if isinstance(parsed, ToolCall):\n                converted_xml_calls.append(parsed)\n            else:\n                remaining_unparsed.append(parsed)\n\n        if converted_xml_calls:\n            message[\"tool_calls\"] = message.get(\"tool_calls\", []) + converted_xml_calls\n        if remaining_unparsed:\n            message[\"unparsed_tool_calls\"] = remaining_unparsed\n        else:\n            message.pop(\"unparsed_tool_calls\", None)\n\n    def parse_response(self, response: list[int]) -> tuple[Message, bool]:\n        \"\"\"Parse response with Qwen3.5-specific post-processing.\"\"\"\n        message, success = super().parse_response(response)\n        if not success:\n            return message, success\n\n        self._postprocess_parsed_message(message)\n        return message, success\n\n    def _parse_response_for_streaming(self, response: list[int]) -> tuple[Message, bool]:\n        \"\"\"Parse response for streaming with Qwen3.5-specific post-processing.\"\"\"\n        message, parse_success = super()._parse_response_for_streaming(response)\n        
self._postprocess_parsed_message(message)\n        return message, parse_success\n\n    def _format_tool_call_xml(self, tool_call: ToolCall) -> str:\n        \"\"\"Format a single tool call in Qwen3.5's XML parameter format.\"\"\"\n        args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}\n        lines = [f\"<tool_call>\\n<function={tool_call.function.name}>\"]\n        for param_name, param_value in args.items():\n            if isinstance(param_value, (dict, list)):\n                value_str = json.dumps(param_value)\n            else:\n                value_str = str(param_value)\n            lines.append(f\"<parameter={param_name}>\\n{value_str}\\n</parameter>\")\n        lines.append(\"</function>\\n</tool_call>\")\n        return \"\\n\".join(lines)\n\n    def _format_tool_calls_chunks(self, message: Message) -> list[ImagePart | TextPart]:\n        \"\"\"Format tool_calls using Qwen3.5's XML parameter format.\"\"\"\n        assert \"tool_calls\" in message, \"tool_calls are required to format tool calls\"\n        return [\n            TextPart(\n                type=\"text\",\n                text=\"\\n\\n\"\n                + \"\\n\".join(self._format_tool_call_xml(tc) for tc in message[\"tool_calls\"]),\n            )\n        ]\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> list[Message]:\n        \"\"\"Create system message with Qwen3.5 tool specifications.\n\n        Qwen3.5 uses a different tool declaration format from Qwen3, with XML-based\n        function/parameter calling syntax.\n\n        Reference: https://huggingface.co/Qwen/Qwen3.5-4B/blob/main/tokenizer_config.json\n        \"\"\"\n        tools_text = \"\"\n        if tools:\n            tool_lines = \"\\n\".join(json.dumps(tool) for tool in tools)\n            tools_text = (\n                \"# Tools\\n\\n\"\n                \"You have access to the following functions:\\n\\n\"\n                \"<tools>\\n\"\n                f\"{tool_lines}\\n\"\n                \"</tools>\\n\\n\"\n                \"If you choose to call a function ONLY reply in the following format with NO suffix:\\n\\n\"\n                \"<tool_call>\\n\"\n                \"<function=example_function_name>\\n\"\n                \"<parameter=example_parameter_1>\\n\"\n                \"value_1\\n\"\n                \"</parameter>\\n\"\n                \"<parameter=example_parameter_2>\\n\"\n                \"This is the value for the second parameter\\n\"\n                \"that can span\\n\"\n                \"multiple lines\\n\"\n                \"</parameter>\\n\"\n                \"</function>\\n\"\n                \"</tool_call>\\n\\n\"\n                \"<IMPORTANT>\\n\"\n                \"Reminder:\\n\"\n                \"- Function calls MUST follow the specified format: \"\n                \"an inner <function=...></function> block must be nested within \"\n                \"<tool_call></tool_call> XML tags\\n\"\n                \"- Required parameters MUST be specified\\n\"\n                \"- You may provide optional reasoning for your function call in natural language \"\n                \"BEFORE the function call, but NOT after\\n\"\n                \"- If there is no function call available, answer the question like normal with \"\n                \"your current knowledge and do not tell the user about function calls\\n\"\n                \"</IMPORTANT>\"\n            )\n\n        if tools_text:\n            
content = tools_text + \"\\n\\n\" + system_prompt if system_prompt else tools_text\n        else:\n            content = system_prompt\n\n        return [Message(role=\"system\", content=content)]\n\n\nclass Qwen3_5DisableThinkingRenderer(Qwen3_5Renderer):\n    \"\"\"\n    Renderer for Qwen3.5 models with thinking disabled.\n\n    Matches the Qwen3.5 HF template with enable_thinking=False. The only difference\n    from Qwen3_5Renderer is the generation suffix: <think>\\\\n\\\\n</think>\\\\n\\\\n instead\n    of <think>\\\\n, signaling to the model to respond directly without reasoning.\n    \"\"\"\n\n    def _get_generation_suffix(self, role: Role, ctx: RenderContext) -> list[int]:\n        maybe_newline = \"\\n\" if ctx.idx > 0 else \"\"\n        header_str = f\"{maybe_newline}<|im_start|>{role}\\n<think>\\n\\n</think>\\n\\n\"\n        return self.tokenizer.encode(header_str, add_special_tokens=False)\n"
  },
  {
    "path": "tinker_cookbook/renderers/qwen3_test.py",
    "content": "\"\"\"Tests specific to Qwen3 renderers (parse_response, streaming, disable-thinking behavior).\n\nAlso covers Qwen3.5 response normalization (prefilled <think> tag restoration) for\nboth batch and streaming paths.\n\"\"\"\n\nfrom typing import TypeGuard, cast\n\nimport pytest\nfrom transformers.models.auto.tokenization_auto import AutoTokenizer\n\nfrom tinker_cookbook.renderers import (\n    Message,\n    StreamingMessageHeader,\n    StreamingTextDelta,\n    StreamingThinkingDelta,\n    TextPart,\n    ThinkingPart,\n    get_renderer,\n)\nfrom tinker_cookbook.renderers.base import ensure_list\nfrom tinker_cookbook.renderers.qwen3 import Qwen3Renderer\nfrom tinker_cookbook.renderers.qwen3_5 import Qwen3_5Renderer\nfrom tinker_cookbook.renderers.testing_utils import extract_token_ids\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\ndef _is_message(obj) -> TypeGuard[Message]:\n    \"\"\"Check if object is a Message dict (TypedDict doesn't support isinstance).\"\"\"\n    return isinstance(obj, dict) and \"role\" in obj and \"content\" in obj\n\n\n# =============================================================================\n# Qwen3 parse_response Tests\n# =============================================================================\n\n\ndef test_qwen3_parse_response_extracts_thinking():\n    \"\"\"Test Qwen3Renderer.parse_response extracts thinking to ThinkingPart.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"<think>Let me reason about this.</think>The answer is 42.<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    assert message[\"role\"] == \"assistant\"\n\n    content = message[\"content\"]\n    assert isinstance(content, list)\n\n    thinking_parts = [p for p in content if p[\"type\"] == \"thinking\"]\n    text_parts = [p for p in content if p[\"type\"] == \"text\"]\n\n    assert len(thinking_parts) == 1\n    assert thinking_parts[0][\"thinking\"] == \"Let me reason about this.\"\n\n    assert len(text_parts) == 1\n    assert text_parts[0][\"text\"] == \"The answer is 42.\"\n\n\ndef test_qwen3_parse_response_multiple_think_blocks():\n    \"\"\"Test Qwen3Renderer.parse_response handles multiple interleaved think blocks.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"<think>step 1</think>partial answer<think>step 2</think>final answer<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n    assert len(content) == 4\n\n    assert content[0] == ThinkingPart(type=\"thinking\", thinking=\"step 1\")\n    assert content[1] == TextPart(type=\"text\", text=\"partial answer\")\n    assert content[2] == ThinkingPart(type=\"thinking\", thinking=\"step 2\")\n    assert content[3] == TextPart(type=\"text\", text=\"final answer\")\n\n\ndef test_qwen3_parse_response_no_thinking_returns_string():\n    \"\"\"Test Qwen3Renderer.parse_response returns string when no thinking.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"Just a plain response without thinking.<|im_end|>\"\n    response_tokens = 
tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    # Content should remain a string for backward compatibility\n    assert isinstance(message[\"content\"], str)\n    assert message[\"content\"] == \"Just a plain response without thinking.\"\n\n\ndef test_qwen3_parse_response_with_tool_calls():\n    \"\"\"Test Qwen3Renderer.parse_response puts tool calls in message['tool_calls'], not content.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = '<think>Let me search</think>I will search for that.<tool_call>{\"name\": \"web_search\", \"arguments\": {\"query\": \"weather\"}}</tool_call><|im_end|>'\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n\n    # Content should only have ThinkingPart and TextPart — no tool calls\n    assert len(content) == 2\n    assert content[0][\"type\"] == \"thinking\"\n    assert content[0][\"thinking\"] == \"Let me search\"\n    assert content[1][\"type\"] == \"text\"\n    assert content[1][\"text\"] == \"I will search for that.\"\n\n    # Tool calls live exclusively in message[\"tool_calls\"]\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 1\n    assert message[\"tool_calls\"][0].function.name == \"web_search\"\n\n\ndef test_qwen3_parse_response_tool_call_only():\n    \"\"\"Test Qwen3Renderer.parse_response with only a tool call.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = (\n        '<tool_call>{\"name\": \"calculator\", \"arguments\": {\"expr\": \"2+2\"}}</tool_call><|im_end|>'\n    )\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success\n    content = message[\"content\"]\n    assert isinstance(content, list)\n    # Content should be empty — only a tool call, no text or thinking\n    assert len(content) == 0\n\n    # Tool call lives in message[\"tool_calls\"]\n    assert \"tool_calls\" in message and len(message[\"tool_calls\"]) == 1\n    assert message[\"tool_calls\"][0].function.name == \"calculator\"\n\n\n# =============================================================================\n# Qwen3 Disable-Thinking Tests\n# =============================================================================\n\n\ndef _get_basic_2turn() -> list[Message]:\n    return [\n        {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        {\"role\": \"assistant\", \"content\": \"I'm fine, thank you!\"},\n    ]\n\n\ndef _get_basic_3turn() -> list[Message]:\n    return [\n        {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        {\"role\": \"assistant\", \"content\": \"I'm fine, thank you!\"},\n        {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n    ]\n\n\ndef _get_basic_4turn() -> list[Message]:\n    return [\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\"role\": \"assistant\", \"content\": \"The answer is 4.\"},\n        {\"role\": \"user\", \"content\": \"And what is 3+3?\"},\n        {\"role\": \"assistant\", \"content\": \"The answer is 6.\"},\n    ]\n\n\ndef test_qwen3_disable_thinking_supervised():\n    
\"\"\"\n    Test that Qwen3DisableThinkingRenderer adds the correct empty thinking block\n    to assistant messages for SFT, matching HF tokenizer with thinking=False.\n    \"\"\"\n    model_name = \"Qwen/Qwen3-8B\"\n    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n    renderer = get_renderer(\"qwen3_disable_thinking\", tokenizer)\n\n    messages = _get_basic_2turn()\n\n    model_input, _ = renderer.build_supervised_example(messages)\n    tinker_tokens = model_input.to_ints()\n    tinker_decoded = tokenizer.decode(tinker_tokens)\n\n    # Get expected format from official Qwen3 tokenizer with thinking=False\n    hf_decoded = tokenizer.apply_chat_template(\n        cast(list[dict[str, str]], messages), tokenize=False, thinking=False\n    )\n\n    # Verify the complete empty thinking block is present\n    assert \"<think>\\n\\n</think>\\n\\n\" in tinker_decoded, (\n        f\"Renderer must add '<think>\\\\n\\\\n</think>\\\\n\\\\n' but got: {tinker_decoded}\"\n    )\n\n    # Verify matches HF\n    assert tinker_decoded == hf_decoded.rstrip(\"\\n\"), (\n        f\"Tinker and HuggingFace outputs differ:\\n\"\n        f\"TINKER:\\n{tinker_decoded!r}\\n\\n\"\n        f\"HUGGINGFACE:\\n{hf_decoded!r}\"\n    )\n\n\ndef test_qwen3_disable_thinking_generation():\n    \"\"\"Test Qwen3DisableThinkingRenderer generation matches HF with enable_thinking=False.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-8B\")\n    cookbook_renderer = get_renderer(\"qwen3_disable_thinking\", tokenizer)\n\n    convo = _get_basic_3turn()\n\n    cookbook_tokens = cookbook_renderer.build_generation_prompt(convo).to_ints()\n    hf_tokens = tokenizer.apply_chat_template(\n        cast(list[dict[str, str]], convo),\n        add_generation_prompt=True,\n        tokenize=True,\n        enable_thinking=False,\n    )\n\n    hf_tokens_list = extract_token_ids(hf_tokens)\n\n    assert cookbook_tokens == hf_tokens_list, (\n        f\"Cookbook tokens: {cookbook_tokens}\\n\"\n        f\"Cookbook string: {tokenizer.decode(cookbook_tokens)}\\n\"\n        f\"HF tokens: {hf_tokens_list}\\n\"\n        f\"HF string: {tokenizer.decode(hf_tokens_list)}\"\n    )\n\n\ndef test_qwen3_disable_thinking_4turn():\n    \"\"\"\n    Test Qwen3DisableThinkingRenderer with 4-turn conversation.\n    Only the last assistant message should have the empty thinking block\n    (historical thinking is stripped, matching HF behavior).\n    \"\"\"\n    model_name = \"Qwen/Qwen3-8B\"\n    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n    renderer = get_renderer(\"qwen3_disable_thinking\", tokenizer)\n\n    messages = _get_basic_4turn()\n\n    model_input, _ = renderer.build_supervised_example(messages)\n    tinker_tokens = model_input.to_ints()\n    tinker_decoded = tokenizer.decode(tinker_tokens)\n\n    # Get expected format from HF\n    hf_decoded = tokenizer.apply_chat_template(\n        cast(list[dict[str, str]], messages), tokenize=False, thinking=False\n    )\n\n    assert tinker_decoded == hf_decoded.rstrip(\"\\n\"), (\n        f\"Tinker and HuggingFace outputs differ:\\n\"\n        f\"TINKER:\\n{tinker_decoded!r}\\n\\n\"\n        f\"HUGGINGFACE:\\n{hf_decoded!r}\"\n    )\n\n\n# =============================================================================\n# Qwen3 Streaming Parsing Tests\n# =============================================================================\n\n\ndef test_qwen3_streaming_simple_text():\n    \"\"\"Test streaming parsing of simple text response without 
thinking.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"Hello, world!<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert deltas[0].role == \"assistant\"\n\n    assert _is_message(deltas[-1])\n    assert deltas[-1][\"role\"] == \"assistant\"\n\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n    assert \"Hello, world!\" in text_content\n\n\ndef test_qwen3_streaming_with_thinking():\n    \"\"\"Test streaming parsing with thinking blocks.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"<think>Let me reason about this.</think>The answer is 42.<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert deltas[0].role == \"assistant\"\n\n    thinking_content = \"\".join(d.thinking for d in deltas if isinstance(d, StreamingThinkingDelta))\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n\n    assert \"Let me reason about this.\" in thinking_content\n    assert \"The answer is 42.\" in text_content\n\n    final_message = deltas[-1]\n    assert _is_message(final_message)\n\n\ndef test_qwen3_streaming_matches_batch():\n    \"\"\"Test that streaming parse produces same final message as batch parse.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"<think>Step 1: Analyze.\\nStep 2: Compute.</think>The result is 123.<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    batch_message, batch_success = renderer.parse_response(response_tokens)\n    assert batch_success\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n    streaming_message = deltas[-1]\n\n    assert _is_message(streaming_message)\n    assert streaming_message[\"role\"] == batch_message[\"role\"]\n    assert ensure_list(streaming_message[\"content\"]) == ensure_list(batch_message[\"content\"])\n\n\ndef test_qwen3_streaming_content_index_increments():\n    \"\"\"Test that content_index increments when switching content types.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"<think>thinking</think>text<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    thinking_indices = [d.content_index for d in deltas if isinstance(d, StreamingThinkingDelta)]\n    text_indices = [d.content_index for d in deltas if isinstance(d, StreamingTextDelta)]\n\n    if thinking_indices and text_indices:\n        assert max(text_indices) > min(thinking_indices)\n\n\ndef test_qwen3_streaming_empty_response():\n    \"\"\"Test streaming parsing of empty/minimal response.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = 
list(renderer.parse_response_streaming(response_tokens))\n\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert _is_message(deltas[-1])\n\n\ndef test_qwen3_streaming_multiple_think_blocks():\n    \"\"\"Test streaming with multiple interleaved think blocks.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"<think>first thought</think>partial<think>second thought</think>final<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    batch_message, _ = renderer.parse_response(response_tokens)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    thinking_content = \"\".join(d.thinking for d in deltas if isinstance(d, StreamingThinkingDelta))\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n\n    assert \"first thought\" in thinking_content\n    assert \"second thought\" in thinking_content\n    assert \"partial\" in text_content\n    assert \"final\" in text_content\n\n    streaming_message = deltas[-1]\n    assert _is_message(streaming_message)\n    assert ensure_list(streaming_message[\"content\"]) == ensure_list(batch_message[\"content\"])\n\n\ndef test_qwen3_streaming_no_unnecessary_buffering():\n    \"\"\"Test that we don't buffer more than necessary when no tag prefix matches.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"Hello world<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n    assert text_content == \"Hello world\"\n\n\ndef test_qwen3_streaming_with_emoji():\n    \"\"\"Test that streaming parser handles multi-byte UTF-8 (emoji) correctly.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = Qwen3Renderer(tokenizer)\n\n    response_str = \"<think>Let me think 🤔</think>Here's a party 🎉!<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    thinking_content = \"\".join(d.thinking for d in deltas if isinstance(d, StreamingThinkingDelta))\n    text_content = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n\n    assert \"�\" not in thinking_content, f\"Thinking has replacement chars: {thinking_content!r}\"\n    assert \"�\" not in text_content, f\"Text has replacement chars: {text_content!r}\"\n\n    assert \"🤔\" in thinking_content\n    assert \"🎉\" in text_content\n\n\n@pytest.mark.parametrize(\n    \"renderer_name\",\n    [\"qwen3\", \"qwen3_disable_thinking\", \"qwen3_instruct\"],\n)\ndef test_qwen3_streaming_supported_by_text_variants(renderer_name):\n    \"\"\"All text-only Qwen3 renderer variants support streaming.\"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    response_str = \"<think>reasoning</think>answer<|im_end|>\"\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert _is_message(deltas[-1])\n\n\n# =============================================================================\n# Qwen3 
Streaming vs Batch Equivalence Tests\n# =============================================================================\n\n\ndef _assert_streaming_matches_batch(renderer, response_str: str):\n    \"\"\"Helper: verify streaming and batch parsing produce identical results.\"\"\"\n    tokenizer = renderer.tokenizer\n    response_tokens = tokenizer.encode(response_str, add_special_tokens=False)\n\n    batch_message, batch_success = renderer.parse_response(response_tokens)\n    deltas = list(renderer.parse_response_streaming(response_tokens))\n\n    assert len(deltas) >= 2, \"Should have at least header + final message\"\n    assert isinstance(deltas[0], StreamingMessageHeader)\n    assert _is_message(deltas[-1])\n\n    streaming_message = deltas[-1]\n    assert streaming_message[\"role\"] == batch_message[\"role\"]\n    assert ensure_list(streaming_message[\"content\"]) == ensure_list(batch_message[\"content\"])\n    assert streaming_message.get(\"tool_calls\") == batch_message.get(\"tool_calls\")\n    assert streaming_message.get(\"unparsed_tool_calls\") == batch_message.get(\"unparsed_tool_calls\")\n\n    # Verify streamed deltas reconstruct the content\n    thinking_from_deltas = \"\".join(\n        d.thinking for d in deltas if isinstance(d, StreamingThinkingDelta)\n    )\n    text_from_deltas = \"\".join(d.text for d in deltas if isinstance(d, StreamingTextDelta))\n\n    batch_content = batch_message[\"content\"]\n    if isinstance(batch_content, list):\n        expected_thinking = \"\".join(p[\"thinking\"] for p in batch_content if p[\"type\"] == \"thinking\")\n        expected_text = \"\".join(p[\"text\"] for p in batch_content if p[\"type\"] == \"text\")\n    else:\n        expected_thinking = \"\"\n        expected_text = batch_content\n\n    assert thinking_from_deltas == expected_thinking\n    # Text deltas may include tool call markup before final parsing strips it\n    if not batch_message.get(\"tool_calls\") and not batch_message.get(\"unparsed_tool_calls\"):\n        assert text_from_deltas == expected_text\n\n    return deltas, batch_message\n\n\nclass TestQwen3StreamingBatchEquivalence:\n    \"\"\"Verify parse_response_streaming matches parse_response for all patterns.\"\"\"\n\n    @pytest.fixture\n    def renderer(self):\n        tokenizer = get_tokenizer(\"Qwen/Qwen3-30B-A3B\")\n        return Qwen3Renderer(tokenizer)\n\n    def test_simple_text(self, renderer):\n        _assert_streaming_matches_batch(renderer, \"Hello, world!<|im_end|>\")\n\n    def test_thinking_then_text(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>Let me reason step by step.\\n1. First...\\n2. Then...</think>\"\n            \"The answer is 42.<|im_end|>\",\n        )\n\n    def test_empty_thinking(self, renderer):\n        _assert_streaming_matches_batch(renderer, \"<think></think>Direct answer.<|im_end|>\")\n\n    def test_long_thinking(self, renderer):\n        thinking = (\n            \"First, let me understand the problem.\\n\\n\"\n            \"Key concepts:\\n1. Superposition\\n2. Measurement\\n3. 
Non-locality\\n\\n\"\n            \"I should explain this clearly.\"\n        )\n        _assert_streaming_matches_batch(\n            renderer, f\"<think>{thinking}</think>Quantum entanglement links particles.<|im_end|>\"\n        )\n\n    def test_multiple_think_blocks(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>first thought</think>partial<think>second thought</think>final<|im_end|>\",\n        )\n\n    def test_empty_response(self, renderer):\n        _assert_streaming_matches_batch(renderer, \"<|im_end|>\")\n\n    def test_whitespace_only(self, renderer):\n        _assert_streaming_matches_batch(renderer, \"   \\n\\t  <|im_end|>\")\n\n    def test_special_characters(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>x² + y² = r²</think>Special chars: <>&\\\"'`~!@#$%^&*()<|im_end|>\",\n        )\n\n    def test_emoji(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer, \"<think>🤔 thinking 💭</think>Answer 🎉✨!<|im_end|>\"\n        )\n\n    def test_code_blocks(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>Need a function.</think>\"\n            \"```python\\ndef hello():\\n    print('world')\\n```<|im_end|>\",\n        )\n\n    def test_html_like_content(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>HTML example</think><div><p>Hello</p></div><|im_end|>\",\n        )\n\n    def test_tool_call_with_thinking(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>I need to search.</think>I will search.\"\n            '<tool_call>\\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"weather\"}}\\n</tool_call>'\n            \"<|im_end|>\",\n        )\n\n    def test_tool_call_without_thinking(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            '<tool_call>\\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"SF\"}}\\n</tool_call>'\n            \"<|im_end|>\",\n        )\n\n    def test_multiline_thinking(self, renderer):\n        _assert_streaming_matches_batch(\n            renderer,\n            \"<think>\\nStep 1\\n\\nStep 2\\n\\nStep 3\\n</think>\\nResult.\\n<|im_end|>\",\n        )\n\n    def test_no_end_token(self, renderer):\n        \"\"\"Truncated response — streaming should still parse think blocks.\"\"\"\n        tokenizer = renderer.tokenizer\n        response_tokens = tokenizer.encode(\n            \"<think>reasoning</think>partial\", add_special_tokens=False\n        )\n\n        deltas = list(renderer.parse_response_streaming(response_tokens))\n        final = deltas[-1]\n        assert _is_message(final)\n        content = final[\"content\"]\n        assert isinstance(content, list), \"Truncated response should still parse think blocks\"\n        thinking = [p for p in content if p[\"type\"] == \"thinking\"]\n        text = [p for p in content if p[\"type\"] == \"text\"]\n        assert len(thinking) == 1 and thinking[0][\"thinking\"] == \"reasoning\"\n        assert len(text) == 1 and text[0][\"text\"] == \"partial\"\n\n    def test_content_index_ordering(self, renderer):\n        \"\"\"Content index strictly increases across type transitions.\"\"\"\n        response_tokens = renderer.tokenizer.encode(\n            \"<think>t1</think>x1<think>t2</think>x2<|im_end|>\", add_special_tokens=False\n        )\n        deltas = 
list(renderer.parse_response_streaming(response_tokens))\n\n        indexed = []\n        for d in deltas:\n            if isinstance(d, StreamingThinkingDelta):\n                indexed.append((\"thinking\", d.content_index))\n            elif isinstance(d, StreamingTextDelta):\n                indexed.append((\"text\", d.content_index))\n\n        indices = [idx for _, idx in indexed]\n        assert indices == sorted(indices), f\"Not monotonic: {indexed}\"\n        for i in range(1, len(indexed)):\n            if indexed[i][0] != indexed[i - 1][0]:\n                assert indexed[i][1] > indexed[i - 1][1]\n\n\n# =============================================================================\n# Qwen3.5 Prefill Normalization Tests\n#\n# Qwen3.5's generation suffix includes <think>\\n, so sampled tokens don't\n# include the opening <think>\\n. Both parse_response and parse_response_streaming\n# must restore it via _normalize_response_tokens.\n# =============================================================================\n\n\n@pytest.fixture\ndef qwen3_5_tokenizer():\n    return get_tokenizer(\"Qwen/Qwen3.5-35B-A3B\")\n\n\n@pytest.fixture\ndef qwen3_5_renderer(qwen3_5_tokenizer):\n    return Qwen3_5Renderer(qwen3_5_tokenizer)\n\n\ndef test_qwen3_5_parse_response_restores_prefilled_think_tag(qwen3_5_tokenizer, qwen3_5_renderer):\n    \"\"\"parse_response should restore <think>\\\\n when it was prefilled by generation prompt.\"\"\"\n    # Simulate sampled tokens after <think>\\n prefill: \"reasoning\\n</think>\\n\\nanswer<|im_end|>\"\n    response_tokens = qwen3_5_tokenizer.encode(\n        \"reasoning\\n</think>\\n\\nanswer<|im_end|>\",\n        add_special_tokens=False,\n    )\n\n    parsed_message, parse_success = qwen3_5_renderer.parse_response(response_tokens)\n\n    assert parse_success is True\n    assert isinstance(parsed_message[\"content\"], list)\n    assert parsed_message[\"content\"] == [\n        ThinkingPart(type=\"thinking\", thinking=\"reasoning\"),\n        TextPart(type=\"text\", text=\"answer\"),\n    ]\n\n\ndef test_qwen3_5_parse_response_streaming_restores_prefilled_think_tag(\n    qwen3_5_tokenizer, qwen3_5_renderer\n):\n    \"\"\"parse_response_streaming should restore <think>\\\\n when it was prefilled.\"\"\"\n    response_tokens = qwen3_5_tokenizer.encode(\n        \"reasoning\\n</think>\\n\\nanswer<|im_end|>\",\n        add_special_tokens=False,\n    )\n\n    deltas = list(qwen3_5_renderer.parse_response_streaming(response_tokens))\n    thinking_text = \"\".join(\n        delta.thinking for delta in deltas if isinstance(delta, StreamingThinkingDelta)\n    )\n    output_text = \"\".join(delta.text for delta in deltas if isinstance(delta, StreamingTextDelta))\n    final_message = cast(Message, deltas[-1])\n\n    assert \"reasoning\" in thinking_text\n    assert \"answer\" in output_text\n    assert _is_message(final_message)\n    assert isinstance(final_message[\"content\"], list)\n    assert final_message[\"content\"] == [\n        ThinkingPart(type=\"thinking\", thinking=\"reasoning\"),\n        TextPart(type=\"text\", text=\"answer\"),\n    ]\n\n\ndef test_qwen3_5_streaming_matches_batch_with_prefilled_think(qwen3_5_tokenizer, qwen3_5_renderer):\n    \"\"\"Streaming and batch should produce identical results for prefilled think tokens.\"\"\"\n    response_tokens = qwen3_5_tokenizer.encode(\n        \"step 1\\nstep 2\\n</think>\\n\\nThe result is 42.<|im_end|>\",\n        add_special_tokens=False,\n    )\n\n    batch_message, batch_success = 
qwen3_5_renderer.parse_response(response_tokens)\n    assert batch_success\n\n    deltas = list(qwen3_5_renderer.parse_response_streaming(response_tokens))\n    streaming_message = deltas[-1]\n\n    assert _is_message(streaming_message)\n    assert streaming_message[\"role\"] == batch_message[\"role\"]\n    assert ensure_list(streaming_message[\"content\"]) == ensure_list(batch_message[\"content\"])\n\n\ndef test_qwen3_5_normalize_noop_when_think_present(qwen3_5_tokenizer, qwen3_5_renderer):\n    \"\"\"When response already starts with <think>\\\\n, normalization is a no-op.\"\"\"\n    response_tokens = qwen3_5_tokenizer.encode(\n        \"<think>\\nreasoning\\n</think>\\n\\nanswer<|im_end|>\",\n        add_special_tokens=False,\n    )\n\n    # Both paths should work identically\n    batch_message, batch_success = qwen3_5_renderer.parse_response(response_tokens)\n    assert batch_success\n\n    deltas = list(qwen3_5_renderer.parse_response_streaming(response_tokens))\n    streaming_message = deltas[-1]\n\n    assert _is_message(streaming_message)\n    assert ensure_list(streaming_message[\"content\"]) == ensure_list(batch_message[\"content\"])\n"
  },
  {
    "path": "tinker_cookbook/renderers/qwen3_tool_declaration_test.py",
    "content": "\"\"\"Tests for Qwen tool declaration format compatibility with HuggingFace.\n\nThese tests verify that Qwen-family renderers produce identical tool declarations\nto HuggingFace's chat templates when using the tools parameter.\n\"\"\"\n\nimport json\nfrom collections.abc import Mapping, Sequence\n\nimport pytest\nfrom transformers import AutoTokenizer\n\nfrom tinker_cookbook.renderers import get_renderer\nfrom tinker_cookbook.renderers.base import Message, ToolSpec, ensure_text\nfrom tinker_cookbook.renderers.testing_utils import extract_token_ids\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n# Qwen3 models use JSON tool calls with OpenAI-style tool wrapper in tool declarations.\nQWEN3_MODELS = [\n    (\"Qwen/Qwen3-30B-A3B\", \"qwen3\"),\n    (\"Qwen/Qwen3-30B-A3B-Instruct-2507\", \"qwen3_instruct\"),\n]\n\n# Qwen3.5 models use XML tool calls and raw function specs in tool declarations.\nQWEN3_5_MODELS = [\n    (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5\"),\n    (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5_disable_thinking\"),\n]\n\nALL_QWEN_MODELS = QWEN3_MODELS + QWEN3_5_MODELS\n\n\ndef _hf_tools_for_model(\n    model_name: str, tools_toolspec: list[ToolSpec]\n) -> Sequence[Mapping[str, object]]:\n    \"\"\"Build the tools payload matching each model family's HF chat-template contract.\"\"\"\n    if \"Qwen3.5\" in model_name:\n        return list(tools_toolspec)\n    return [{\"type\": \"function\", \"function\": tool} for tool in tools_toolspec]\n\n\ndef _hf_template_kwargs(renderer_name: str) -> dict:\n    \"\"\"Return renderer-specific kwargs for HF apply_chat_template.\"\"\"\n    if renderer_name == \"qwen3_5_disable_thinking\":\n        return {\"enable_thinking\": False}\n    return {}\n\n\n@pytest.mark.parametrize(\"model_name,renderer_name\", QWEN3_MODELS)\ndef test_qwen3_tool_json_formatting(model_name: str, renderer_name: str):\n    \"\"\"Test that Qwen3 tool JSON uses correct separators to match HF.\n\n    HF's tojson filter uses:\n    - separators=(', ', ': ') with spaces after colons/commas\n    - No key sorting (preserves insertion order)\n    \"\"\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    # Tools with nested structure\n    tools: list[ToolSpec] = [\n        {\n            \"name\": \"search\",\n            \"description\": \"Search the web\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"query\": {\"type\": \"string\", \"description\": \"Search query\"},\n                    \"max_results\": {\"type\": \"integer\", \"description\": \"Max results\"},\n                },\n                \"required\": [\"query\"],\n            },\n        }\n    ]\n\n    messages = renderer.create_conversation_prefix_with_tools(tools, \"\")\n    system_msg = messages[0]\n    content_str = ensure_text(system_msg[\"content\"])\n\n    # Extract the JSON from the <tools>...</tools> section\n    start_marker = \"<tools>\\n\"\n    end_marker = \"\\n</tools>\"\n    start_idx = content_str.index(start_marker) + len(start_marker)\n    end_idx = content_str.index(end_marker)\n    tool_json_str = content_str[start_idx:end_idx]\n\n    # Parse to verify it's valid JSON\n    _ = json.loads(tool_json_str)\n\n    # Re-serialize with HF-compatible settings (no sort_keys)\n    expected_json = json.dumps(\n        {\"type\": \"function\", \"function\": tools[0]},\n        separators=(\", \", \": \"),\n    )\n\n    # Check formatting\n    assert 
tool_json_str == expected_json, (\n        f\"JSON formatting doesn't match HF expectations.\\n\"\n        f\"Expected (HF format):\\n{expected_json}\\n\\n\"\n        f\"Got (cookbook):\\n{tool_json_str}\\n\\n\"\n        f\"Differences:\\n\"\n        f\"  - Separators: HF uses (', ', ': ') with spaces\\n\"\n        f\"  - Key order: HF preserves insertion order (no sorting)\"\n    )\n\n    # Verify JSON uses spaces after colons\n    assert '\": ' in tool_json_str, \"JSON should have space after colons\"\n    assert '\", ' in tool_json_str, \"JSON should have space after commas\"\n\n\n@pytest.mark.parametrize(\"model_name,renderer_name\", ALL_QWEN_MODELS)\ndef test_qwen3_tool_declaration_matches_hf_tokens(model_name: str, renderer_name: str):\n    \"\"\"Test that tool declaration produces identical tokens to HuggingFace.\"\"\"\n    tokenizer = get_tokenizer(model_name)\n    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    # Define tools in ToolSpec format (what tinker-cookbook accepts)\n    tools_toolspec: list[ToolSpec] = [\n        {\n            \"name\": \"get_weather\",\n            \"description\": \"Get current weather\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\"type\": \"string\", \"description\": \"City name\"},\n                    \"units\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n                },\n                \"required\": [\"location\"],\n            },\n        }\n    ]\n\n    tools_for_hf = _hf_tools_for_model(model_name, tools_toolspec)\n\n    messages_list: list[Message] = [{\"role\": \"user\", \"content\": \"What's the weather in SF?\"}]\n\n    # Tinker-cookbook approach\n    convo = renderer.create_conversation_prefix_with_tools(tools_toolspec, \"\") + messages_list\n    cookbook_tokens = renderer.build_generation_prompt(convo).to_ints()\n\n    # HuggingFace approach\n    hf_tokens = hf_tokenizer.apply_chat_template(\n        messages_list,\n        tools=tools_for_hf,\n        tokenize=True,\n        add_generation_prompt=True,\n        **_hf_template_kwargs(renderer_name),\n    )\n\n    hf_tokens_list = extract_token_ids(hf_tokens)\n\n    assert cookbook_tokens == hf_tokens_list, (\n        f\"Token mismatch between cookbook and HF!\\n\"\n        f\"Cookbook tokens ({len(cookbook_tokens)}): {cookbook_tokens}\\n\"\n        f\"Cookbook string:\\n{tokenizer.decode(cookbook_tokens)}\\n\\n\"\n        f\"HF tokens ({len(hf_tokens_list)}): {hf_tokens_list}\\n\"\n        f\"HF string:\\n{hf_tokenizer.decode(hf_tokens_list)}\"\n    )\n\n\n@pytest.mark.parametrize(\"model_name,renderer_name\", ALL_QWEN_MODELS)\ndef test_qwen3_tool_declaration_string_matches_hf(model_name: str, renderer_name: str):\n    \"\"\"Test that tool declaration produces identical string to HuggingFace.\"\"\"\n    tokenizer = get_tokenizer(model_name)\n    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    tools_toolspec: list[ToolSpec] = [\n        {\n            \"name\": \"calculate\",\n            \"description\": \"Perform calculation\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"expression\": {\"type\": \"string\"},\n                },\n                \"required\": [\"expression\"],\n            },\n        }\n    ]\n\n    tools_for_hf = 
_hf_tools_for_model(model_name, tools_toolspec)\n    messages_list: list[Message] = [{\"role\": \"user\", \"content\": \"What is 2+2?\"}]\n\n    # Tinker-cookbook approach\n    convo = renderer.create_conversation_prefix_with_tools(tools_toolspec, \"\") + messages_list\n    cookbook_tokens = renderer.build_generation_prompt(convo).to_ints()\n    cookbook_string = tokenizer.decode(cookbook_tokens)\n\n    # HuggingFace approach\n    hf_string = hf_tokenizer.apply_chat_template(\n        messages_list,\n        tools=tools_for_hf,\n        tokenize=False,\n        add_generation_prompt=True,\n        **_hf_template_kwargs(renderer_name),\n    )\n\n    assert cookbook_string == hf_string, (\n        f\"String mismatch between cookbook and HF!\\n\"\n        f\"=== COOKBOOK ===\\n{cookbook_string}\\n\\n\"\n        f\"=== HF ===\\n{hf_string}\"\n    )\n\n\n@pytest.mark.parametrize(\"model_name,renderer_name\", ALL_QWEN_MODELS)\ndef test_qwen3_multiple_tools(model_name: str, renderer_name: str):\n    \"\"\"Test that multiple tools are formatted correctly.\"\"\"\n    tokenizer = get_tokenizer(model_name)\n    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    tools_toolspec: list[ToolSpec] = [\n        {\n            \"name\": \"get_weather\",\n            \"description\": \"Get weather\",\n            \"parameters\": {\"type\": \"object\", \"properties\": {}},\n        },\n        {\n            \"name\": \"get_time\",\n            \"description\": \"Get time\",\n            \"parameters\": {\"type\": \"object\", \"properties\": {}},\n        },\n    ]\n\n    tools_for_hf = _hf_tools_for_model(model_name, tools_toolspec)\n    messages_list: list[Message] = [{\"role\": \"user\", \"content\": \"Hello\"}]\n\n    # Tinker-cookbook approach\n    convo = renderer.create_conversation_prefix_with_tools(tools_toolspec, \"\") + messages_list\n    cookbook_tokens = renderer.build_generation_prompt(convo).to_ints()\n\n    # HuggingFace approach\n    hf_tokens = hf_tokenizer.apply_chat_template(\n        messages_list,\n        tools=tools_for_hf,\n        tokenize=True,\n        add_generation_prompt=True,\n        **_hf_template_kwargs(renderer_name),\n    )\n\n    hf_tokens_list = extract_token_ids(hf_tokens)\n\n    assert cookbook_tokens == hf_tokens_list, (\n        f\"Token mismatch with multiple tools!\\n\"\n        f\"Cookbook: {tokenizer.decode(cookbook_tokens)}\\n\\n\"\n        f\"HF: {hf_tokenizer.decode(hf_tokens_list)}\"\n    )\n\n\n@pytest.mark.parametrize(\"model_name,renderer_name\", ALL_QWEN_MODELS)\ndef test_qwen3_empty_tools_list(model_name: str, renderer_name: str):\n    \"\"\"Test that empty tools list doesn't include tool section.\"\"\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    messages = renderer.create_conversation_prefix_with_tools([], \"\")\n\n    # Should return a system message with just the default system prompt (or empty)\n    assert len(messages) == 1\n    assert messages[0][\"role\"] == \"system\"\n    # Should not contain tool-related text\n    assert \"<tools>\" not in messages[0][\"content\"]\n\n\n@pytest.mark.parametrize(\"model_name,renderer_name\", ALL_QWEN_MODELS)\ndef test_qwen3_custom_system_prompt_with_tools(model_name: str, renderer_name: str):\n    \"\"\"Test that custom system prompt is combined with tools.\"\"\"\n    tokenizer = get_tokenizer(model_name)\n    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)\n    renderer = 
get_renderer(renderer_name, tokenizer)\n\n    custom_prompt = \"You are a helpful assistant.\"\n    tools_toolspec: list[ToolSpec] = [\n        {\n            \"name\": \"search\",\n            \"description\": \"Search\",\n            \"parameters\": {\"type\": \"object\", \"properties\": {}},\n        }\n    ]\n\n    tools_for_hf = _hf_tools_for_model(model_name, tools_toolspec)\n    messages_list: list[Message] = [{\"role\": \"user\", \"content\": \"Help me\"}]\n\n    # Tinker-cookbook approach\n    convo = (\n        renderer.create_conversation_prefix_with_tools(tools_toolspec, custom_prompt)\n        + messages_list\n    )\n    cookbook_tokens = renderer.build_generation_prompt(convo).to_ints()\n\n    # HuggingFace approach - need to manually add system message\n    hf_messages = [{\"role\": \"system\", \"content\": custom_prompt}] + messages_list\n    hf_tokens = hf_tokenizer.apply_chat_template(\n        hf_messages,\n        tools=tools_for_hf,\n        tokenize=True,\n        add_generation_prompt=True,\n        **_hf_template_kwargs(renderer_name),\n    )\n\n    hf_tokens_list = extract_token_ids(hf_tokens)\n\n    assert cookbook_tokens == hf_tokens_list, (\n        f\"Token mismatch with custom system prompt!\\n\"\n        f\"Cookbook: {tokenizer.decode(cookbook_tokens)}\\n\\n\"\n        f\"HF: {hf_tokenizer.decode(hf_tokens_list)}\"\n    )\n\n\n@pytest.mark.parametrize(\"model_name,renderer_name\", ALL_QWEN_MODELS)\ndef test_qwen3_preserves_insertion_order(model_name: str, renderer_name: str):\n    \"\"\"Test that JSON keys preserve insertion order (not sorted).\"\"\"\n    tokenizer = get_tokenizer(model_name)\n    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    # Tool with properties in specific order\n    tools_toolspec: list[ToolSpec] = [\n        {\n            \"name\": \"complex_tool\",\n            \"description\": \"A complex tool\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"zebra\": {\"type\": \"string\"},\n                    \"apple\": {\"type\": \"string\"},\n                },\n            },\n        }\n    ]\n\n    tools_for_hf = _hf_tools_for_model(model_name, tools_toolspec)\n    messages_list: list[Message] = [{\"role\": \"user\", \"content\": \"Test\"}]\n\n    # Tinker-cookbook approach\n    convo = renderer.create_conversation_prefix_with_tools(tools_toolspec, \"\") + messages_list\n    cookbook_tokens = renderer.build_generation_prompt(convo).to_ints()\n\n    # HuggingFace approach\n    hf_tokens = hf_tokenizer.apply_chat_template(\n        messages_list,\n        tools=tools_for_hf,\n        tokenize=True,\n        add_generation_prompt=True,\n        **_hf_template_kwargs(renderer_name),\n    )\n\n    # Should match exactly (HF doesn't sort, preserves insertion order)\n    hf_tokens_list = extract_token_ids(hf_tokens)\n\n    assert cookbook_tokens == hf_tokens_list, (\n        f\"Token mismatch - key ordering issue!\\n\"\n        f\"Cookbook: {tokenizer.decode(cookbook_tokens)}\\n\\n\"\n        f\"HF: {hf_tokenizer.decode(hf_tokens_list)}\"\n    )\n"
  },
  {
    "path": "tinker_cookbook/renderers/renderer_pickle_test.py",
    "content": "\"\"\"Tests for picklability of Renderers.\n\nRenderers created via get_renderer() must survive pickle roundtrips so that\nEnvGroupBuilder instances (which often hold Renderer references) can be\nserialized for distributed rollout execution.\n\"\"\"\n\nimport pickle\n\nimport pytest\n\nfrom tinker_cookbook.renderers import get_renderer, register_renderer, unregister_renderer\nfrom tinker_cookbook.renderers.base import Renderer\nfrom tinker_cookbook.tokenizer_utils import Tokenizer, get_tokenizer\n\n# Models that don't require special access / are commonly available in CI.\n# Each entry is (renderer_name, model_name).\n_TEXT_RENDERERS = [\n    (\"role_colon\", \"meta-llama/Llama-3.1-8B-Instruct\"),\n    (\"llama3\", \"meta-llama/Llama-3.1-8B-Instruct\"),\n    (\"qwen3\", \"Qwen/Qwen3-8B\"),\n    (\"qwen3_disable_thinking\", \"Qwen/Qwen3-8B\"),\n    (\"qwen3_instruct\", \"Qwen/Qwen3-8B\"),\n    (\"deepseekv3\", \"deepseek-ai/DeepSeek-V3-0324\"),\n    (\"deepseekv3_disable_thinking\", \"deepseek-ai/DeepSeek-V3-0324\"),\n    (\"deepseekv3_thinking\", \"deepseek-ai/DeepSeek-V3-0324\"),\n]\n\n\n@pytest.fixture(params=_TEXT_RENDERERS, ids=[r[0] for r in _TEXT_RENDERERS])\ndef renderer_and_model(request: pytest.FixtureRequest) -> tuple[str, str]:\n    return request.param\n\n\nclass TestRendererPickle:\n    def test_pickle_roundtrip(self, renderer_and_model: tuple[str, str]) -> None:\n        \"\"\"Renderers created via get_renderer() survive pickle roundtrip.\"\"\"\n        renderer_name, model_name = renderer_and_model\n        tokenizer = get_tokenizer(model_name)\n        renderer = get_renderer(renderer_name, tokenizer)\n\n        restored = pickle.loads(pickle.dumps(renderer))\n\n        assert restored._renderer_name == renderer_name\n        assert restored._model_name == renderer._model_name\n        assert type(restored) is type(renderer)\n        assert restored.get_stop_sequences() == renderer.get_stop_sequences()\n\n    def test_pickle_metadata_set(self, renderer_and_model: tuple[str, str]) -> None:\n        \"\"\"get_renderer() stamps _renderer_name and _model_name.\"\"\"\n        renderer_name, model_name = renderer_and_model\n        tokenizer = get_tokenizer(model_name)\n        renderer = get_renderer(renderer_name, tokenizer)\n\n        assert renderer._renderer_name == renderer_name\n        # _model_name may differ from model_name due to tokenizer remapping (e.g., Llama 3)\n        assert renderer._model_name == tokenizer.name_or_path\n\n    def test_pickle_without_metadata_raises(self) -> None:\n        \"\"\"Renderers created directly (not via get_renderer()) raise on pickle.\"\"\"\n        tokenizer = get_tokenizer(\"meta-llama/Llama-3.1-8B-Instruct\")\n\n        from tinker_cookbook.renderers.llama3 import Llama3Renderer\n\n        renderer = Llama3Renderer(tokenizer)\n        # _renderer_name and _model_name are None\n        with pytest.raises(pickle.PicklingError, match=\"not set\"):\n            pickle.dumps(renderer)\n\n    def test_pickle_with_manual_metadata(self) -> None:\n        \"\"\"Manually setting pickle metadata works for direct-constructed renderers.\"\"\"\n        tokenizer = get_tokenizer(\"meta-llama/Llama-3.1-8B-Instruct\")\n\n        from tinker_cookbook.renderers.llama3 import Llama3Renderer\n\n        renderer = Llama3Renderer(tokenizer)\n        renderer._renderer_name = \"llama3\"\n        renderer._model_name = \"meta-llama/Llama-3.1-8B-Instruct\"\n        renderer._has_image_processor = False\n\n        restored = 
pickle.loads(pickle.dumps(renderer))\n        assert type(restored) is Llama3Renderer\n        assert restored.get_stop_sequences() == renderer.get_stop_sequences()\n\n    def test_pickle_without_metadata_vl_renderer(self) -> None:\n        \"\"\"VL renderers that bypass super().__init__() still raise clean PicklingError.\"\"\"\n        tokenizer = get_tokenizer(\"Qwen/Qwen3-8B\")\n\n        from tinker_cookbook.renderers.qwen3 import Qwen3VLRenderer\n\n        # Qwen3VLRenderer bypasses super().__init__(), so _renderer_name is never\n        # set via __init__. Class-level defaults + getattr in __reduce__ handle this.\n        renderer = Qwen3VLRenderer(tokenizer, image_processor=None)\n        with pytest.raises(pickle.PicklingError, match=\"not set\"):\n            pickle.dumps(renderer)\n\n    def test_pickle_with_explicit_model_name(self) -> None:\n        \"\"\"The model_name param in get_renderer() overrides tokenizer.name_or_path.\"\"\"\n        tokenizer = get_tokenizer(\"meta-llama/Llama-3.1-8B-Instruct\")\n        # tokenizer.name_or_path is remapped, but we can override it\n        renderer = get_renderer(\"llama3\", tokenizer, model_name=\"meta-llama/Llama-3.1-8B-Instruct\")\n\n        assert renderer._model_name == \"meta-llama/Llama-3.1-8B-Instruct\"\n\n        restored = pickle.loads(pickle.dumps(renderer))\n        assert restored._model_name == \"meta-llama/Llama-3.1-8B-Instruct\"\n        assert type(restored) is type(renderer)\n\n    def test_pickle_custom_registered_renderer(self) -> None:\n        \"\"\"Custom renderers registered via register_renderer() are pickle-safe.\"\"\"\n        from tinker_cookbook.renderers.role_colon import RoleColonRenderer\n\n        def my_factory(tokenizer: Tokenizer, image_processor: object = None) -> Renderer:\n            return RoleColonRenderer(tokenizer)\n\n        register_renderer(\"test_custom_pickle\", my_factory)\n        try:\n            tokenizer = get_tokenizer(\"meta-llama/Llama-3.1-8B-Instruct\")\n            renderer = get_renderer(\"test_custom_pickle\", tokenizer)\n\n            assert renderer._renderer_name == \"test_custom_pickle\"\n\n            restored = pickle.loads(pickle.dumps(renderer))\n            assert type(restored) is RoleColonRenderer\n            assert restored._renderer_name == \"test_custom_pickle\"\n        finally:\n            unregister_renderer(\"test_custom_pickle\")\n\n\nclass TestMessageCompleterPickle:\n    def test_tinker_message_completer_pickle_structure(self) -> None:\n        \"\"\"TinkerMessageCompleter fields are individually pickleable (Renderer + SamplingClient).\n\n        We test the Renderer part here; SamplingClient has its own __reduce__ in the SDK.\n        \"\"\"\n        tokenizer = get_tokenizer(\"meta-llama/Llama-3.1-8B-Instruct\")\n        renderer = get_renderer(\"llama3\", tokenizer)\n\n        # Just verify the renderer component pickles fine when it would be inside a completer\n        restored_renderer = pickle.loads(pickle.dumps(renderer))\n        assert type(restored_renderer) is type(renderer)\n        assert restored_renderer.get_stop_sequences() == renderer.get_stop_sequences()\n"
  },
  {
    "path": "tinker_cookbook/renderers/renderers_test.py",
    "content": "\"\"\"\nTests for tinker_cookbook renderers against HuggingFace chat templates.\n\nThese tests verify that tinker-cookbook renderers produce identical token sequences\nto HuggingFace's chat templates. This is important because:\n\n1. The OpenAI-compatible inference endpoint (/chat/completions) uses HuggingFace\n   chat templates to render conversations to tokens.\n2. Users who train with tinker-cookbook and want to use the OpenAI endpoint for\n   inference need their training to use HF-compatible rendering.\n\nFor models with thinking capabilities (Qwen3, DeepSeek), we test both the default\nrenderer (thinking enabled) and the disable_thinking variant.\n\nSee docs/rendering.mdx for more details on the rendering system.\nSee docs/compatible-apis/openai.mdx for the OpenAI-compatible endpoint documentation.\n\nTesting guidelines:\n- Don't test things that are clearly verified by HF equivalence tests (build_generation_prompt,\n  build_supervised_example with basic conversations). HF equivalence tests ensure correctness.\n- DO test parse_response and parsing logic - HF doesn't do parsing, so we need those tests.\n- Keep tests focused on tricky logic, not trivial operations.\n\"\"\"\n\nimport copy\nimport json\nimport random\nimport uuid\nfrom collections.abc import Callable\n\nimport pytest\n\nfrom tinker_cookbook.image_processing_utils import get_image_processor\nfrom tinker_cookbook.model_info import get_model_attributes, get_recommended_renderer_name\nfrom tinker_cookbook.renderers import (\n    DeepSeekV3ThinkingRenderer,\n    Message,\n    Qwen3Renderer,\n    TextPart,\n    ThinkingPart,\n    ToolCall,\n    TrainOnWhat,\n    get_registered_renderer_names,\n    get_renderer,\n    is_renderer_registered,\n    register_renderer,\n    unregister_renderer,\n)\nfrom tinker_cookbook.renderers.base import ContentPart, ensure_list, ensure_text\nfrom tinker_cookbook.renderers.kimi_k2 import KimiK2Renderer\nfrom tinker_cookbook.renderers.kimi_k25 import KimiK25Renderer\nfrom tinker_cookbook.renderers.nemotron3 import Nemotron3Renderer\nfrom tinker_cookbook.renderers.qwen3_5 import Qwen3_5DisableThinkingRenderer, Qwen3_5Renderer\nfrom tinker_cookbook.renderers.testing_utils import (\n    extract_token_ids,\n    skip_if_deepseek_tokenizer_bug,\n)\nfrom tinker_cookbook.tokenizer_utils import (\n    get_registered_tokenizer_names,\n    get_tokenizer,\n    is_tokenizer_registered,\n    register_tokenizer,\n    unregister_tokenizer,\n)\n\n# =============================================================================\n# Conversation Generator (seeded random conversations for parametrized tests)\n# =============================================================================\n\n\ndef _rand_str(rng: random.Random, length: int = 8) -> str:\n    return uuid.UUID(int=rng.getrandbits(128)).hex[:length]\n\n\ndef _rand_tool_call(rng: random.Random) -> ToolCall:\n    return ToolCall(\n        function=ToolCall.FunctionBody(\n            name=f\"tool_{_rand_str(rng, 6)}\",\n            arguments=json.dumps({\"arg\": _rand_str(rng)}),\n        ),\n        id=f\"call_{_rand_str(rng)}\",\n    )\n\n\ndef generate_conversation(\n    seed: int,\n    *,\n    include_system: bool = True,\n    include_tool_calls: bool = True,\n    include_thinking: bool = True,\n    min_turns: int = 2,\n    max_turns: int = 10,\n    end_with_assistant: bool = True,\n) -> list[Message]:\n    rng = random.Random(seed)\n    messages: list[Message] = []\n\n    if include_system and rng.random() < 0.5:\n        
messages.append(Message(role=\"system\", content=f\"system_{_rand_str(rng)}\"))\n\n    num_turns = rng.randint(min_turns, max_turns)\n\n    for turn in range(num_turns):\n        is_last_turn = turn == num_turns - 1\n        messages.append(Message(role=\"user\", content=f\"user_{_rand_str(rng)}\"))\n\n        if is_last_turn and not end_with_assistant:\n            break\n\n        has_thinking = include_thinking and rng.random() < 0.5\n        has_tool_call = include_tool_calls and rng.random() < 0.3\n\n        if has_thinking:\n            content: str | list[ContentPart] = [\n                ThinkingPart(type=\"thinking\", thinking=f\"think_{_rand_str(rng)}\"),\n                TextPart(type=\"text\", text=f\"asst_{_rand_str(rng)}\"),\n            ]\n        else:\n            content = f\"asst_{_rand_str(rng)}\"\n\n        if has_tool_call:\n            tool_call = _rand_tool_call(rng)\n            messages.append(Message(role=\"assistant\", content=content, tool_calls=[tool_call]))\n            assert tool_call.id is not None\n            messages.append(\n                Message(\n                    role=\"tool\",\n                    content=json.dumps({\"result\": _rand_str(rng)}),\n                    tool_call_id=tool_call.id,\n                    name=tool_call.function.name,\n                )\n            )\n            messages.append(Message(role=\"assistant\", content=f\"followup_{_rand_str(rng)}\"))\n        else:\n            messages.append(Message(role=\"assistant\", content=content))\n\n    return messages\n\n\n# =============================================================================\n# Test Conversation Definitions\n# =============================================================================\n# These functions provide reusable conversations for testing different scenarios.\n# Each returns a list of Message dicts that can be used across multiple tests.\n\n\ndef get_basic_3turn_conversation() -> list[Message]:\n    \"\"\"Simple 3-turn conversation: user -> assistant -> user.\n\n    This is the standard test case for generation prompt testing.\n    \"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        {\"role\": \"assistant\", \"content\": \"I'm fine, thank you!\"},\n        {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n    ]\n\n\ndef get_basic_2turn_conversation() -> list[Message]:\n    \"\"\"Simple 2-turn conversation: user -> assistant.\n\n    This is the standard test case for supervised example testing.\n    \"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        {\"role\": \"assistant\", \"content\": \"I'm fine, thank you!\"},\n    ]\n\n\ndef get_system_message_3turn_conversation() -> list[Message]:\n    \"\"\"3-turn conversation with a nontrivial system message.\n\n    Tests that renderers correctly handle system messages with real instructions.\n    \"\"\"\n    return [\n        {\n            \"role\": \"system\",\n            \"content\": \"You are a helpful coding assistant. 
Always explain your reasoning step by step.\",\n        },\n        {\"role\": \"user\", \"content\": \"How do I reverse a string in Python?\"},\n        {\"role\": \"assistant\", \"content\": \"You can use slicing: `s[::-1]` to reverse a string.\"},\n        {\"role\": \"user\", \"content\": \"Can you show me another way?\"},\n    ]\n\n\ndef get_system_message_2turn_conversation() -> list[Message]:\n    \"\"\"2-turn conversation with a nontrivial system message.\n\n    Tests that renderers correctly handle system messages with real instructions.\n    \"\"\"\n    return [\n        {\n            \"role\": \"system\",\n            \"content\": \"You are a helpful coding assistant. Always explain your reasoning step by step.\",\n        },\n        {\"role\": \"user\", \"content\": \"How do I reverse a string in Python?\"},\n        {\"role\": \"assistant\", \"content\": \"You can use slicing: `s[::-1]` to reverse a string.\"},\n    ]\n\n\ndef get_basic_4turn_conversation() -> list[Message]:\n    \"\"\"Simple 4-turn conversation: user -> assistant -> user -> assistant.\n\n    This is the standard test case for multi-turn supervised example testing.\n    \"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\"role\": \"assistant\", \"content\": \"The answer is 4.\"},\n        {\"role\": \"user\", \"content\": \"And what is 3+3?\"},\n        {\"role\": \"assistant\", \"content\": \"The answer is 6.\"},\n    ]\n\n\ndef get_tool_call_conversation() -> list[Message]:\n    \"\"\"Full tool use conversation with tool call and response.\n\n    Includes: user request -> assistant tool call -> tool response -> assistant final answer.\n    Ends with assistant (for supervised tests).\n    \"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"What's the weather in San Francisco?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": \"I'll check the weather for you.\",\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"San Francisco\"}',\n                    ),\n                    id=\"call_123\",\n                )\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"content\": '{\"temperature\": 72, \"condition\": \"sunny\"}',\n            \"tool_call_id\": \"call_123\",\n            \"name\": \"get_weather\",  # Required by GptOss, optional for others\n        },\n        {\"role\": \"assistant\", \"content\": \"The weather in San Francisco is sunny with 72°F.\"},\n    ]\n\n\ndef get_tool_call_gen_conversation() -> list[Message]:\n    \"\"\"Tool use conversation for generation testing (ends with tool response).\n\n    Tests generating assistant response after receiving tool output.\n    \"\"\"\n    return get_tool_call_conversation()[:-1]\n\n\ndef get_4turn_thinking_conversation() -> list[Message]:\n    \"\"\"4-turn conversation with ThinkingPart in assistant messages.\n\n    Tests thinking content handling in multi-turn conversations.\n    \"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"First turn reasoning here.\"),\n                TextPart(type=\"text\", text=\"The answer is 4.\"),\n            ],\n        },\n        {\"role\": \"user\", \"content\": \"And what is 
3+3?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"Second turn reasoning here.\"),\n                TextPart(type=\"text\", text=\"The answer is 6.\"),\n            ],\n        },\n    ]\n\n\ndef get_thinking_with_whitespace_conversation() -> list[Message]:\n    \"\"\"Conversation with leading whitespace in ThinkingPart and TextPart.\"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"\\nLet me respond politely.\\n\"),\n                TextPart(type=\"text\", text=\"\\n\\nI'm fine, thank you!\"),\n            ],\n        },\n    ]\n\n\ndef get_multiturn_thinking_conversation() -> list[Message]:\n    \"\"\"Multi-turn conversation with thinking in assistant messages.\n\n    Used for testing extension property with thinking content.\n    \"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"Let me add 2+2.\"),\n                TextPart(type=\"text\", text=\"The answer is 4.\"),\n            ],\n        },\n        {\"role\": \"user\", \"content\": \"What is 3+3?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"Let me add 3+3.\"),\n                TextPart(type=\"text\", text=\"The answer is 6.\"),\n            ],\n        },\n    ]\n\n\ndef get_multiturn_tool_conversation() -> list[Message]:\n    \"\"\"Multi-turn conversation with tool calls.\n\n    Used for testing extension property with tool calling.\n    \"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": \"Let me check the weather for you.\",\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"content\": '{\"temperature\": 72, \"condition\": \"sunny\"}',\n            \"tool_call_id\": \"call_1\",\n            \"name\": \"get_weather\",\n        },\n        {\"role\": \"assistant\", \"content\": \"The weather in NYC is sunny with 72°F.\"},\n        {\"role\": \"user\", \"content\": \"What about San Francisco?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": \"Let me check SF weather.\",\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"San Francisco\"}',\n                    ),\n                    id=\"call_2\",\n                )\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"content\": '{\"temperature\": 65, \"condition\": \"foggy\"}',\n            \"tool_call_id\": \"call_2\",\n            \"name\": \"get_weather\",\n        },\n        {\"role\": \"assistant\", \"content\": \"San Francisco is foggy at 65°F.\"},\n    ]\n\n\ndef 
get_multiturn_thinking_and_tool_conversation() -> list[Message]:\n    \"\"\"Multi-turn conversation with both thinking AND tool calls.\n\n    Tests complex interactions with both thinking blocks and tool calling.\n    \"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"What's the weather in NYC?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"I need to check the weather API.\"),\n                TextPart(type=\"text\", text=\"Let me look that up.\"),\n            ],\n            \"tool_calls\": [\n                ToolCall(\n                    function=ToolCall.FunctionBody(\n                        name=\"get_weather\",\n                        arguments='{\"location\": \"NYC\"}',\n                    ),\n                    id=\"call_1\",\n                )\n            ],\n        },\n        {\n            \"role\": \"tool\",\n            \"content\": '{\"temperature\": 72}',\n            \"tool_call_id\": \"call_1\",\n            \"name\": \"get_weather\",\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"The API returned 72 degrees.\"),\n                TextPart(type=\"text\", text=\"NYC is 72°F.\"),\n            ],\n        },\n        {\"role\": \"user\", \"content\": \"Is that warm?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"72°F is about 22°C, which is pleasant.\"),\n                TextPart(type=\"text\", text=\"Yes, 72°F is comfortable room temperature.\"),\n            ],\n        },\n    ]\n\n\n# Conversation registry for parametrized tests\n# Maps conversation ID to (factory_function, description, requires_tools)\nCONVERSATION_REGISTRY: dict[str, tuple[Callable[[], list[Message]], str, bool]] = {\n    \"basic_3turn\": (get_basic_3turn_conversation, \"basic 3-turn conversation\", False),\n    \"basic_2turn\": (get_basic_2turn_conversation, \"basic 2-turn conversation\", False),\n    \"system_3turn\": (get_system_message_3turn_conversation, \"3-turn with system message\", False),\n    \"system_2turn\": (get_system_message_2turn_conversation, \"2-turn with system message\", False),\n    \"tool_call\": (get_tool_call_conversation, \"tool call conversation (supervised)\", True),\n    \"tool_call_gen\": (get_tool_call_gen_conversation, \"tool call conversation (generation)\", True),\n    \"thinking_4turn\": (get_4turn_thinking_conversation, \"4-turn with thinking\", False),\n    \"thinking_and_tool\": (\n        get_multiturn_thinking_and_tool_conversation,\n        \"multi-turn with thinking and tools\",\n        True,\n    ),\n    # Random conversations ending with user (for generation tests) - with tools/thinking\n    **{\n        f\"random_gen_{seed}\": (\n            lambda s=seed: generate_conversation(s, end_with_assistant=False),\n            f\"random gen conversation (seed={seed})\",\n            True,  # May contain tools\n        )\n        for seed in [1, 42, 123, 456, 999]\n    },\n    # Random conversations ending with assistant (for supervised tests) - with tools/thinking\n    **{\n        f\"random_sup_{seed}\": (\n            lambda s=seed: generate_conversation(s, end_with_assistant=True),\n            f\"random sup conversation (seed={seed})\",\n            True,  # May contain tools\n        )\n        for seed in [1, 42, 123, 456, 999]\n    },\n}\n\n\n# Models 
that support tool calling in their renderers\nTOOL_CAPABLE_MODELS = {\n    \"Qwen/Qwen3-30B-A3B\",\n    \"Qwen/Qwen3-30B-A3B-Instruct-2507\",\n    \"Qwen/Qwen3-VL-30B-A3B-Instruct\",\n    \"Qwen/Qwen3.5-35B-A3B\",\n    \"meta-llama/Llama-3.2-1B-Instruct\",\n    \"deepseek-ai/DeepSeek-V3.1\",\n    \"moonshotai/Kimi-K2-Thinking\",\n    \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\",\n    \"openai/gpt-oss-20b\",\n}\n\n\n# =============================================================================\n# HF Compatibility Tests (parametrized by model and conversation)\n# =============================================================================\n\n\n# Models for HF generation/supervised tests\n# Format: (model_name, renderer_override, hf_kwargs)\n# - model_name: HuggingFace model ID\n# - renderer_override: None to use get_recommended_renderer_name, or a specific renderer name\n# - hf_kwargs: Extra kwargs to pass to apply_chat_template (e.g., {\"thinking\": True})\n_HF_TEST_MODELS = [\n    (\"meta-llama/Llama-3.2-1B-Instruct\", None, {}),\n    (\"Qwen/Qwen3-30B-A3B\", None, {}),\n    (\"Qwen/Qwen3-30B-A3B-Instruct-2507\", None, {}),\n    (\"deepseek-ai/DeepSeek-V3.1\", None, {}),  # non-thinking (default)\n    (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3_thinking\", {\"thinking\": True}),  # thinking mode\n    (\"openai/gpt-oss-20b\", None, {}),\n    (\"moonshotai/Kimi-K2-Thinking\", None, {}),\n    (\"Qwen/Qwen3-VL-30B-A3B-Instruct\", None, {}),\n    (\"Qwen/Qwen3.5-35B-A3B\", None, {}),\n    (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5_disable_thinking\", {\"enable_thinking\": False}),\n    (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", None, {}),\n    (\n        \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\",\n        \"nemotron3_disable_thinking\",\n        {\"enable_thinking\": False},\n    ),\n]\n\n# Models whose tool call format matches HF's apply_chat_template exactly.\n# Excluded models with intentional differences:\n# - Llama3: see llama3.py docstring (double-encoding, assistant content handling)\n# - gpt-oss: no HF template\n_HF_TOOL_COMPATIBLE_MODELS = {\n    \"Qwen/Qwen3-30B-A3B\",\n    \"Qwen/Qwen3-30B-A3B-Instruct-2507\",\n    \"Qwen/Qwen3-VL-30B-A3B-Instruct\",\n    \"Qwen/Qwen3.5-35B-A3B\",\n    \"deepseek-ai/DeepSeek-V3.1\",\n    \"moonshotai/Kimi-K2-Thinking\",\n    \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\",\n}\n\n# Conversations for generation tests (end with user message or tool response)\n# Note: Random conversations are tested in consistency tests, not HF comparison,\n# because they can have complex thinking+tool combinations with HF format quirks.\n_GENERATION_CONVERSATIONS = [\n    \"basic_3turn\",\n    \"system_3turn\",\n    \"tool_call_gen\",\n]\n\n\ndef _conversation_has_tools(messages: list[Message]) -> bool:\n    \"\"\"Check if a conversation contains tool calls or tool responses.\"\"\"\n    return any(\"tool_calls\" in m or m[\"role\"] == \"tool\" for m in messages)\n\n\ndef _add_llama3_date_prefix(messages: list[Message]) -> list[Message]:\n    \"\"\"Add date prefix to messages for Llama models.\"\"\"\n    # Use the hardcoded date from the mirrored tokenizer's chat template\n    date_prefix = \"Cutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024\\n\\n\"\n    messages = copy.deepcopy(messages)\n    if messages and messages[0][\"role\"] == \"system\":\n        assert isinstance(messages[0][\"content\"], str)\n        messages[0][\"content\"] = date_prefix + messages[0][\"content\"]\n    else:\n        messages = [Message(role=\"system\", 
content=date_prefix)] + messages\n    return messages\n\n\n@pytest.mark.parametrize(\"conv_id\", _GENERATION_CONVERSATIONS)\n@pytest.mark.parametrize(\"model_name,renderer_override,hf_kwargs\", _HF_TEST_MODELS)\ndef test_generation_against_hf_chat_templates(\n    model_name: str, renderer_override: str | None, hf_kwargs: dict, conv_id: str\n):\n    \"\"\"Test generation prompt against HF chat templates.\n\n    Parametrized by model and conversation type. Tests that our renderer produces\n    identical tokens to HuggingFace's chat template for the same conversation.\n    \"\"\"\n    conv_factory, conv_desc, requires_tools = CONVERSATION_REGISTRY[conv_id]\n    convo = conv_factory()\n\n    # Skip tool-containing conversations for models with intentional HF differences\n    if _conversation_has_tools(convo) and model_name not in _HF_TOOL_COMPATIBLE_MODELS:\n        pytest.skip(\n            f\"{model_name} has intentional tool format differences from HF (see renderer docstring)\"\n        )\n\n    tokenizer = get_tokenizer(model_name)\n    attributes = get_model_attributes(model_name)\n    image_processor = get_image_processor(model_name) if attributes.is_vl else None\n\n    # Use renderer_override if provided, otherwise use default logic\n    if renderer_override is not None:\n        render_name = renderer_override\n    elif model_name.startswith(\"openai\"):\n        render_name = \"gpt_oss_medium_reasoning\"\n    else:\n        render_name = get_recommended_renderer_name(model_name)\n    cookbook_renderer = get_renderer(render_name, tokenizer, image_processor)\n\n    modified_cookbook_convo = (\n        _add_llama3_date_prefix(convo) if model_name.startswith(\"meta\") else convo\n    )\n    # ^^^ modify the cookbook convo just for llama3, where we chose not to match the HF template\n\n    # Extract tools from tool_declare messages and filter them out when converting to OpenAI format\n    tools_for_hf = None\n    hf_convo = []\n    for m in convo:\n        if m[\"role\"] == \"tool_declare\":\n            # Parse the JSON content to extract tools for HF\n            tools_for_hf = json.loads(ensure_text(m[\"content\"]))\n        else:\n            openai_msg = cookbook_renderer.to_openai_message(m)\n            hf_convo.append(openai_msg)\n\n    cookbook_tokens = cookbook_renderer.build_generation_prompt(modified_cookbook_convo).to_ints()\n    hf_tokens = tokenizer.apply_chat_template(\n        hf_convo, tools=tools_for_hf, add_generation_prompt=True, tokenize=True, **hf_kwargs\n    )\n\n    hf_tokens_list = extract_token_ids(hf_tokens)\n\n    assert cookbook_tokens == hf_tokens_list, (\n        f\"[{conv_desc}] Cookbook tokens: {cookbook_tokens}\\n\"\n        f\"Cookbook string: {tokenizer.decode(cookbook_tokens)}\\n\"\n        f\"HF tokens: {hf_tokens_list}\\n\"\n        f\"HF string: {tokenizer.decode(hf_tokens_list)}\"\n    )\n\n\n# Models for supervised tests\n# Excluded:\n# - gpt-oss: analysis channel diverges from HF template\n# - Qwen/Qwen3-30B-A3B: HF template adds empty <think> blocks to non-thinking messages\n# Format: (model_name, renderer_override, hf_kwargs) - same as _HF_TEST_MODELS\n_SUPERVISED_TEST_MODELS = [\n    (\"meta-llama/Llama-3.2-1B-Instruct\", None, {}),\n    (\"Qwen/Qwen3-30B-A3B-Instruct-2507\", None, {}),\n    (\"deepseek-ai/DeepSeek-V3.1\", None, {}),  # non-thinking (default)\n    (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3_thinking\", {\"thinking\": True}),  # thinking mode\n    (\"moonshotai/Kimi-K2-Thinking\", None, {}),\n    
(\"Qwen/Qwen3-VL-30B-A3B-Instruct\", None, {}),\n    (\"Qwen/Qwen3.5-35B-A3B\", None, {}),\n    (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5_disable_thinking\", {\"enable_thinking\": False}),\n    (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", None, {}),\n    (\n        \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\",\n        \"nemotron3_disable_thinking\",\n        {\"enable_thinking\": False},\n    ),\n]\n\n# Conversations for supervised tests (end with assistant message)\n# Note: Random conversations are tested in consistency tests, not HF comparison.\n_SUPERVISED_CONVERSATIONS = [\n    \"basic_2turn\",\n    \"system_2turn\",\n    \"tool_call\",\n]\n\n\n@pytest.mark.parametrize(\"conv_id\", _SUPERVISED_CONVERSATIONS)\n@pytest.mark.parametrize(\"model_name,renderer_override,hf_kwargs\", _SUPERVISED_TEST_MODELS)\ndef test_supervised_example_against_hf_chat_templates(\n    model_name: str, renderer_override: str | None, hf_kwargs: dict, conv_id: str\n):\n    \"\"\"Test supervised example against HF chat templates.\n\n    Parametrized by model and conversation type. Tests that our renderer produces\n    identical tokens to HuggingFace's chat template for the same conversation.\n    \"\"\"\n    conv_factory, conv_desc, requires_tools = CONVERSATION_REGISTRY[conv_id]\n    convo = conv_factory()\n\n    # Skip tool-containing conversations for models with intentional HF differences\n    if _conversation_has_tools(convo) and model_name not in _HF_TOOL_COMPATIBLE_MODELS:\n        pytest.skip(\n            f\"{model_name} has intentional tool format differences from HF (see renderer docstring)\"\n        )\n\n    # Skip supervised tests for thinking renderer - we intentionally don't add </think> to the\n    # last message (supervised target) so it can preserve ThinkingPart, unlike HF which always adds it\n    if renderer_override in _RENDERERS_WITH_DIFFERENT_SUPERVISED_GEN_HEADERS:\n        pytest.skip(\n            f\"{renderer_override} intentionally differs from HF for supervised target (no </think>)\"\n        )\n\n    tokenizer = get_tokenizer(model_name)\n    attributes = get_model_attributes(model_name)\n    image_processor = get_image_processor(model_name) if attributes.is_vl else None\n\n    # Use renderer_override if provided, otherwise use default logic\n    if renderer_override is not None:\n        render_name = renderer_override\n    elif model_name.startswith(\"openai\"):\n        render_name = \"gpt_oss_medium_reasoning\"\n    else:\n        render_name = get_recommended_renderer_name(model_name)\n    cookbook_renderer = get_renderer(render_name, tokenizer, image_processor)\n\n    modified_cookbook_convo = (\n        _add_llama3_date_prefix(convo) if model_name.startswith(\"meta\") else convo\n    )\n    # ^^^ modify the cookbook convo just for llama3, where we chose not to match the HF template\n\n    # Extract tools from tool_declare messages and filter them out when converting to OpenAI format\n    tools_for_hf = None\n    hf_convo = []\n    for m in convo:\n        if m[\"role\"] == \"tool_declare\":\n            # Parse the JSON content to extract tools for HF\n            tools_for_hf = json.loads(ensure_text(m[\"content\"]))\n        else:\n            openai_msg = cookbook_renderer.to_openai_message(m)\n            hf_convo.append(openai_msg)\n\n    cookbook_model_input, _ = cookbook_renderer.build_supervised_example(modified_cookbook_convo)\n    cookbook_tokens = cookbook_model_input.to_ints()\n    hf_output = tokenizer.apply_chat_template(\n        hf_convo, tools=tools_for_hf, 
tokenize=False, add_generation_prompt=False, **hf_kwargs\n    )\n    assert isinstance(hf_output, str)\n    hf_tokens = tokenizer.encode(hf_output.rstrip(\"\\n\"), add_special_tokens=False)\n\n    assert cookbook_tokens == hf_tokens, (\n        f\"[{conv_desc}] Cookbook tokens: {cookbook_tokens}\\n\"\n        f\"Cookbook string: {tokenizer.decode(cookbook_tokens)}\\n\"\n        f\"HF tokens: {hf_tokens}\\n\"\n        f\"HF string: {tokenizer.decode(hf_tokens)}\"\n    )\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"Qwen/Qwen3-30B-A3B\",\n    ],\n)\ndef test_tokenization_boundary_with_whitespace(model_name: str):\n    \"\"\"Test that whitespace in ThinkingPart/TextPart tokenizes correctly vs HF.\n\n    Qwen3 is excluded from supervised HF tests (empty <think> blocks), so we\n    test the whitespace case separately here.\n    \"\"\"\n    convo = get_thinking_with_whitespace_conversation()\n\n    tokenizer = get_tokenizer(model_name)\n    attributes = get_model_attributes(model_name)\n    image_processor = get_image_processor(model_name) if attributes.is_vl else None\n    render_name = get_recommended_renderer_name(model_name)\n    cookbook_renderer = get_renderer(render_name, tokenizer, image_processor)\n\n    hf_convo = [cookbook_renderer.to_openai_message(m) for m in convo]\n\n    cookbook_model_input, _ = cookbook_renderer.build_supervised_example(convo)\n    cookbook_tokens = cookbook_model_input.to_ints()\n    hf_output = tokenizer.apply_chat_template(hf_convo, tokenize=False, add_generation_prompt=False)\n    assert isinstance(hf_output, str)\n    hf_tokens = tokenizer.encode(hf_output.rstrip(\"\\n\"), add_special_tokens=False)\n\n    # Verify decoded strings match (this should pass)\n    cookbook_decoded = tokenizer.decode(cookbook_tokens)\n    hf_decoded = tokenizer.decode(hf_tokens)\n    assert cookbook_decoded == hf_decoded, \"Decoded strings should match even if tokens differ\"\n\n    # Verify tokens match\n    assert cookbook_tokens == hf_tokens, (\n        f\"Token mismatch with whitespace content.\\n\"\n        f\"Cookbook tokens: {cookbook_tokens}\\n\"\n        f\"HF tokens: {hf_tokens}\\n\"\n        f\"Both decode to: {cookbook_decoded!r}\"\n    )\n\n\n# =============================================================================\n# Tool Use Rendering Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"Qwen/Qwen3-30B-A3B\",\n        \"Qwen/Qwen3.5-35B-A3B\",\n        # Llama3 does not support tool calling - see llama3.py docstring\n        \"deepseek-ai/DeepSeek-V3.1\",\n        \"moonshotai/Kimi-K2-Thinking\",\n        \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\",\n        \"openai/gpt-oss-20b\",\n    ],\n)\ndef test_tool_call_supervised_rendering(model_name: str):\n    \"\"\"Test that tool call conversations render without errors.\n\n    Verifies that our renderers handle tool call conversations correctly\n    for supervised learning.\n    \"\"\"\n    skip_if_deepseek_tokenizer_bug(model_name)\n    convo = get_tool_call_conversation()\n\n    tokenizer = get_tokenizer(model_name)\n    attributes = get_model_attributes(model_name)\n    image_processor = get_image_processor(model_name) if attributes.is_vl else None\n    render_name = (\n        get_recommended_renderer_name(model_name)\n        if not model_name.startswith(\"openai\")\n        else \"gpt_oss_medium_reasoning\"\n    )\n    cookbook_renderer = get_renderer(render_name, tokenizer, 
image_processor)\n\n    # Build supervised example - should not raise\n    model_input, weights = cookbook_renderer.build_supervised_example(convo)\n    tokens = model_input.to_ints()\n    decoded = tokenizer.decode(tokens)\n\n    # Verify basic structure\n    assert len(tokens) > 0, \"Should produce non-empty token sequence\"\n    assert len(weights) == len(tokens), \"Weights should match token count\"\n\n    # Verify tool-related content appears in output\n    # Different renderers format tool calls differently:\n    # - Qwen3: <tool_call>{\"name\": \"get_weather\", ...}</tool_call>\n    # - Llama3: <function=get_weather>...</function>\n    # - DeepSeek: <｜tool▁sep｜>get_weather\n    # - Kimi K2: Uses tool_id (functions.name:idx or just the id) + arguments\n    # - GptOss: to=functions.get_weather<|channel|>commentary <|constrain|>json<|message|>{args}\n    # Check for tool arguments which all formats include\n    assert \"San Francisco\" in decoded, f\"Tool argument should appear in rendered output: {decoded}\"\n\n    # Check for either the function name or the tool_call_id\n    has_tool_indicator = \"get_weather\" in decoded or \"call_123\" in decoded\n    assert has_tool_indicator, f\"Tool name or ID should appear in rendered output: {decoded}\"\n\n\n# =============================================================================\n# Thinking Stripping Tests (multi-turn with thinking content)\n# =============================================================================\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_class\",\n    [\n        (\"Qwen/Qwen3-8B\", Qwen3Renderer),\n        (\"Qwen/Qwen3.5-35B-A3B\", Qwen3_5Renderer),\n        (\"deepseek-ai/DeepSeek-V3.1\", DeepSeekV3ThinkingRenderer),\n        (\"moonshotai/Kimi-K2-Thinking\", KimiK2Renderer),\n        (\"moonshotai/Kimi-K2.5\", KimiK25Renderer),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", Nemotron3Renderer),\n    ],\n)\ndef test_strip_thinking_from_history_default(model_name: str, renderer_class):\n    \"\"\"\n    Test that renderers with strip_thinking_from_history=True (default) only preserve\n    the last assistant message's thinking. 
Earlier assistant thinking blocks are stripped.\n    \"\"\"\n    skip_if_deepseek_tokenizer_bug(model_name)\n    tokenizer = get_tokenizer(model_name)\n    renderer = renderer_class(tokenizer)  # Default strip_thinking_from_history=True\n\n    messages = get_4turn_thinking_conversation()\n    model_input, _ = renderer.build_supervised_example(messages)\n    decoded = tokenizer.decode(model_input.to_ints())\n\n    # First assistant message should have thinking stripped\n    assert \"First turn reasoning\" not in decoded, (\n        f\"First turn thinking should be stripped:\\n{decoded}\"\n    )\n    # Second (last) assistant message should preserve thinking\n    assert \"Second turn reasoning\" in decoded, f\"Last turn thinking should be preserved:\\n{decoded}\"\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_class\",\n    [\n        (\"Qwen/Qwen3-8B\", Qwen3Renderer),\n        (\"Qwen/Qwen3.5-35B-A3B\", Qwen3_5Renderer),\n        (\"deepseek-ai/DeepSeek-V3.1\", DeepSeekV3ThinkingRenderer),\n        (\"moonshotai/Kimi-K2-Thinking\", KimiK2Renderer),\n        (\"moonshotai/Kimi-K2.5\", KimiK25Renderer),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", Nemotron3Renderer),\n    ],\n)\ndef test_strip_thinking_from_history_false(model_name: str, renderer_class):\n    \"\"\"\n    Test that strip_thinking_from_history=False preserves thinking in ALL messages.\n    This mode is used for multi-turn RL where the extension property is needed.\n    \"\"\"\n    skip_if_deepseek_tokenizer_bug(model_name)\n    tokenizer = get_tokenizer(model_name)\n    renderer = renderer_class(tokenizer, strip_thinking_from_history=False)\n\n    messages = get_4turn_thinking_conversation()\n    model_input, _ = renderer.build_supervised_example(messages)\n    decoded = tokenizer.decode(model_input.to_ints())\n\n    # Both thinking blocks should be present\n    assert \"First turn reasoning\" in decoded, (\n        f\"First thinking should be preserved with strip_thinking_from_history=False: {decoded}\"\n    )\n    assert \"Second turn reasoning\" in decoded, f\"Second thinking should be preserved: {decoded}\"\n\n\n# =============================================================================\n# Supervised/Generation/Parse Consistency Tests\n# =============================================================================\n\n\ndef _split_by_weights(tokens: list[int], weights: list[float]) -> tuple[list[int], list[int]]:\n    \"\"\"Split token sequence into observation (weight=0) and action (weight=1) parts.\n\n    Assumes weights are like 000...0111...1 (zeros then ones).\n    Returns (ob, ac) where ob has all weight=0 tokens and ac has all weight=1 tokens.\n    \"\"\"\n    assert len(tokens) == len(weights), (\n        f\"Token/weight length mismatch: {len(tokens)} vs {len(weights)}\"\n    )\n\n    # Find the first non-zero weight\n    first_nonzero = None\n    for i, w in enumerate(weights):\n        if w > 0:\n            first_nonzero = i\n            break\n\n    if first_nonzero is None:\n        # All zeros - no action tokens\n        return tokens, []\n\n    # Verify the pattern: all zeros before first_nonzero, all ones after\n    for i, w in enumerate(weights):\n        if i < first_nonzero:\n            assert w == 0, f\"Expected weight=0 at index {i}, got {w}\"\n        else:\n            assert w == 1, f\"Expected weight=1 at index {i}, got {w}\"\n\n    ob = tokens[:first_nonzero]\n    ac = tokens[first_nonzero:]\n    return ob, ac\n\n\ndef get_2turn_with_thinking() -> list[Message]:\n    
\"\"\"2-turn conversation with thinking content in assistant message.\"\"\"\n    return [\n        {\"role\": \"user\", \"content\": \"Hello, how are you?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                ThinkingPart(type=\"thinking\", thinking=\"Let me respond politely.\"),\n                TextPart(type=\"text\", text=\"I'm fine, thank you!\"),\n            ],\n        },\n    ]\n\n\n# Renderers for the consistency test - (model_name, renderer_name)\n_CONSISTENCY_RENDERERS = [\n    (\"meta-llama/Llama-3.2-1B-Instruct\", \"llama3\"),\n    (\"meta-llama/Llama-3.2-1B-Instruct\", \"role_colon\"),\n    (\"Qwen/Qwen3-8B\", \"qwen3\"),\n    (\"Qwen/Qwen3-8B\", \"qwen3_disable_thinking\"),\n    (\"Qwen/Qwen3-8B\", \"qwen3_instruct\"),\n    (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5\"),\n    (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5_disable_thinking\"),\n    (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3\"),\n    (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3_thinking\"),\n    (\"openai/gpt-oss-20b\", \"gpt_oss_medium_reasoning\"),\n    (\"moonshotai/Kimi-K2-Thinking\", \"kimi_k2\"),\n    (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3\"),\n    (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3_disable_thinking\"),\n]\n\n# Conversations for the consistency test\n_CONSISTENCY_CONVERSATIONS = [\n    get_basic_2turn_conversation,\n    get_2turn_with_thinking,\n    get_tool_call_conversation,\n    # Random conversations with tools/thinking\n    lambda: generate_conversation(1, end_with_assistant=True),\n    lambda: generate_conversation(42, end_with_assistant=True),\n    lambda: generate_conversation(999, end_with_assistant=True),\n]\n\n\n# Renderers that don't support ThinkingPart content (use ensure_text)\n_RENDERERS_WITHOUT_THINKING_SUPPORT = {\"llama3\", \"role_colon\"}\n\n# Renderers that don't support tool calling\n_RENDERERS_WITHOUT_TOOL_SUPPORT = {\"role_colon\"}\n\n# Renderers that strip thinking in non-thinking mode (conversation must not have ThinkingPart)\n_RENDERERS_WITH_THINKING_STRIPPING = {\n    \"qwen3_disable_thinking\",\n    \"qwen3_5_disable_thinking\",\n    \"nemotron3_disable_thinking\",\n    \"deepseekv3\",\n    \"kimi_k2\",\n}\n\n# Renderers where supervised and generation have different headers (HF thinking=True behavior).\n# These add </think> to supervised assistant headers but <think> to generation prompt,\n# so observation != generation_prompt by design.\n_RENDERERS_WITH_DIFFERENT_SUPERVISED_GEN_HEADERS = {\"deepseekv3_thinking\", \"qwen3_5\", \"nemotron3\"}\n\n\n@pytest.mark.parametrize(\"conversation_fn\", _CONSISTENCY_CONVERSATIONS)\n@pytest.mark.parametrize(\"model_name,renderer_name\", _CONSISTENCY_RENDERERS)\ndef test_supervised_generation_parse_consistency(\n    model_name: str, renderer_name: str, conversation_fn\n):\n    \"\"\"Test consistency between build_supervised_example, build_generation_prompt, and parse_response.\n\n    For train_on_what=LAST_ASSISTANT_MESSAGE, this test verifies:\n    1. The supervised example produces weights like 000...0111...1\n    2. Split tokens into (ob, ac) based on weights\n    3. ob == build_generation_prompt(messages[:-1]).to_ints()\n    4. 
parse_response(ac) returns the final message\n\n    This ensures that:\n    - The observation tokens match what the model would see at generation time\n    - The action tokens can be parsed back to the original message\n    \"\"\"\n    skip_if_deepseek_tokenizer_bug(model_name)\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    messages = conversation_fn()\n\n    # Check if this combination is supported based on actual message content\n    def has_thinking(msgs):\n        for m in msgs:\n            content = m[\"content\"]\n            if isinstance(content, list):\n                for p in content:\n                    if p.get(\"type\") == \"thinking\":\n                        return True\n        return False\n\n    def has_tools(msgs):\n        return any(\"tool_calls\" in m or m[\"role\"] == \"tool\" for m in msgs)\n\n    has_thinking_content = has_thinking(messages)\n    has_tool_content = has_tools(messages)\n\n    if has_thinking_content and renderer_name in _RENDERERS_WITHOUT_THINKING_SUPPORT:\n        pytest.skip(f\"{renderer_name} doesn't support ThinkingPart content\")\n    if has_thinking_content and renderer_name in _RENDERERS_WITH_THINKING_STRIPPING:\n        pytest.skip(f\"{renderer_name} strips thinking content, breaking roundtrip consistency\")\n    if has_tool_content and renderer_name in _RENDERERS_WITHOUT_TOOL_SUPPORT:\n        pytest.skip(f\"{renderer_name} doesn't support tool calling\")\n    if renderer_name in _RENDERERS_WITH_DIFFERENT_SUPERVISED_GEN_HEADERS:\n        pytest.skip(\n            f\"{renderer_name} has different headers for supervised (</think>) vs generation (<think>)\"\n        )\n    assert len(messages) >= 2, \"Need at least 2 messages for this test\"\n    assert messages[-1][\"role\"] == \"assistant\", \"Last message must be assistant\"\n\n    prefix_messages = messages[:-1]\n    final_message = messages[-1]\n\n    # Build supervised example\n    from tinker_cookbook.renderers import TrainOnWhat\n\n    model_input, weights = renderer.build_supervised_example(\n        messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_MESSAGE\n    )\n    sup_tokens = model_input.to_ints()\n    weights_list = weights.tolist()\n\n    # Split into observation and action\n    ob, ac = _split_by_weights(sup_tokens, weights_list)\n\n    # Build generation prompt for prefix\n    gen_prompt = renderer.build_generation_prompt(prefix_messages)\n    gen_tokens = gen_prompt.to_ints()\n\n    # Check 1: Observation should match generation prompt\n    ob_matches_gen = ob == gen_tokens\n    if not ob_matches_gen:\n        # Find where they diverge\n        min_len = min(len(ob), len(gen_tokens))\n        diverge_idx = min_len\n        for i in range(min_len):\n            if ob[i] != gen_tokens[i]:\n                diverge_idx = i\n                break\n\n        ob_decoded = tokenizer.decode(ob)\n        gen_decoded = tokenizer.decode(gen_tokens)\n\n        # Show the discrepancy\n        raise AssertionError(\n            f\"Observation tokens do not match generation prompt for {renderer_name}.\\n\"\n            f\"Divergence at token {diverge_idx}:\\n\"\n            f\"  ob[{diverge_idx}:]:  {ob[diverge_idx : diverge_idx + 10]} = {tokenizer.decode(ob[diverge_idx : diverge_idx + 10])!r}\\n\"\n            f\"  gen[{diverge_idx}:]: {gen_tokens[diverge_idx : diverge_idx + 10]} = {tokenizer.decode(gen_tokens[diverge_idx : diverge_idx + 10])!r}\\n\"\n            f\"\\nFull observation ({len(ob)} 
tokens):\\n{ob_decoded!r}\\n\"\n            f\"\\nFull generation prompt ({len(gen_tokens)} tokens):\\n{gen_decoded!r}\"\n        )\n\n    # Check 2: Parse the action tokens\n    parsed_message, parse_success = renderer.parse_response(ac)\n\n    # Check parse success\n    assert parse_success, (\n        f\"Failed to parse action tokens for {renderer_name}.\\n\"\n        f\"Action tokens: {ac}\\n\"\n        f\"Decoded: {tokenizer.decode(ac)!r}\\n\"\n        f\"Parsed message: {parsed_message}\"\n    )\n\n    # Check 3: Parsed content should match final message content\n    # Normalize both to list form for comparison (handles string vs list[TextPart])\n    parsed_normalized = ensure_list(parsed_message[\"content\"])\n    expected_normalized = ensure_list(final_message[\"content\"])\n    assert parsed_normalized == expected_normalized, (\n        f\"Parsed content does not match final message for {renderer_name}.\\n\"\n        f\"Expected: {expected_normalized!r}\\n\"\n        f\"Got: {parsed_normalized!r}\"\n    )\n\n\n# =============================================================================\n# EOT Parsing Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_name\",\n    [\n        (\"Qwen/Qwen3-30B-A3B\", \"qwen3\"),\n        (\"Qwen/Qwen3-8B\", \"qwen3_disable_thinking\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5_disable_thinking\"),\n        (\"meta-llama/Llama-3.2-1B-Instruct\", \"llama3\"),\n        # deepseekv3 defaults to non-thinking, deepseekv3_thinking is thinking mode\n        (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3\"),\n        (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3_thinking\"),\n        (\"openai/gpt-oss-20b\", \"gpt_oss_medium_reasoning\"),\n        (\"moonshotai/Kimi-K2-Thinking\", \"kimi_k2\"),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3\"),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3_disable_thinking\"),\n    ],\n)\ndef test_eot_parsing(model_name: str, renderer_name: str):\n    \"\"\"Test EOT token parsing behavior for different renderers using real tokenizers.\"\"\"\n    skip_if_deepseek_tokenizer_bug(model_name)\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    # Get the appropriate EOT token for each renderer\n    # Note: DeepSeek uses full-width pipes (｜) not ASCII pipes (|)\n    eot_tokens = {\n        \"llama3\": \"<|eot_id|>\",\n        \"qwen3\": \"<|im_end|>\",\n        \"qwen3_disable_thinking\": \"<|im_end|>\",\n        \"qwen3_5\": \"<|im_end|>\",\n        \"qwen3_5_disable_thinking\": \"<|im_end|>\",\n        \"deepseekv3\": \"<｜end▁of▁sentence｜>\",  # Full-width pipes\n        \"deepseekv3_thinking\": \"<｜end▁of▁sentence｜>\",  # Full-width pipes\n        \"deepseekv3_disable_thinking\": \"<｜end▁of▁sentence｜>\",  # Full-width pipes (alias)\n        \"gpt_oss_medium_reasoning\": \"<|return|>\",\n        \"kimi_k2\": \"<|im_end|>\",\n        \"nemotron3\": \"<|im_end|>\",\n        \"nemotron3_disable_thinking\": \"<|im_end|>\",\n    }\n    eot_token = eot_tokens.get(renderer_name)\n    if eot_token is None:\n        raise ValueError(f\"Unknown renderer: {renderer_name}\")\n\n    # Test case 1: Normal case with single EOT - should parse correctly\n    test_response_with_eot = f\"53 + 18 = 71{eot_token}\"\n    response_tokens = tokenizer.encode(test_response_with_eot, add_special_tokens=False)\n\n 
   message, format_correct = renderer.parse_response(response_tokens)\n    assert message[\"role\"] == \"assistant\"\n    assert message[\"content\"] == \"53 + 18 = 71\"\n    assert format_correct is True\n\n    # Test case 2: No EOT token - should have format=False\n    test_response_no_eot = \"53 + 18 = 71\"\n    response_tokens_no_eot = tokenizer.encode(test_response_no_eot, add_special_tokens=False)\n\n    message, format_correct = renderer.parse_response(response_tokens_no_eot)\n    assert message[\"role\"] == \"assistant\"\n    assert message[\"content\"] == \"53 + 18 = 71\"\n    assert format_correct is False\n\n    # Test case 3: Double EOT token - should raise ValueError\n    test_response_double_eot = f\"53 + 18 = 71{eot_token}{eot_token}\"\n    response_tokens_double_eot = tokenizer.encode(\n        test_response_double_eot, add_special_tokens=False\n    )\n\n    with pytest.raises(ValueError, match=r\"expected .* 1\"):\n        _ = renderer.parse_response(response_tokens_double_eot)\n\n\n# =============================================================================\n# No User Messages Edge Case Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_name\",\n    [\n        (\"meta-llama/Llama-3.2-1B-Instruct\", \"llama3\"),\n        (\"Qwen/Qwen3-8B\", \"qwen3\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5_disable_thinking\"),\n        (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3\"),\n    ],\n)\ndef test_supervised_example_no_user_messages(model_name: str, renderer_name: str):\n    \"\"\"Test that build_supervised_example doesn't crash when there are no user messages.\n\n    Regression test: previously, `max(idx for ... if role == 'user')` raised ValueError\n    on an empty sequence when no user messages were present. Now uses `default=-1`.\n    With LAST_ASSISTANT_MESSAGE and no user messages, all assistant tokens should be trained on\n    (since every message is \"after the last user\").\n    \"\"\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    messages: list[Message] = [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"assistant\", \"content\": \"Hello! 
How can I help you?\"},\n    ]\n\n    # Should not raise ValueError\n    model_input, weights = renderer.build_supervised_example(\n        messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_MESSAGE\n    )\n    tokens = model_input.to_ints()\n\n    assert len(tokens) > 0, \"Should produce non-empty token sequence\"\n    assert len(weights) == len(tokens), \"Weights should match token count\"\n\n    # With no user messages, the assistant message is \"after the last user\" (last_user_idx == -1),\n    # so the assistant tokens should have weight=1\n    assert any(w > 0 for w in weights.tolist()), (\n        \"At least some tokens should have non-zero weight for LAST_ASSISTANT_MESSAGE with no user messages\"\n    )\n\n\n# =============================================================================\n# Sequence Extension Property Tests\n# =============================================================================\n\n\ndef _verify_extension_property(renderer, messages: list[Message], tokenizer):\n    \"\"\"\n    Verify the sequence extension property for multi-turn conversations.\n\n    The extension property holds when the full sequence at timestep t (observation + action)\n    is a prefix of the observation at timestep t+1. This enables KV-cache reuse and O(T)\n    compute scaling for T-turn trajectories.\n\n    For a conversation [user1, asst1, user2, asst2, ...], we check:\n    - (prompt_before_asst1 + asst1_completion) is prefix of prompt_before_asst2\n    - (prompt_before_asst2 + asst2_completion) is prefix of prompt_before_asst3\n    - etc.\n\n    The \"completion\" for an assistant message is how it would be rendered as the model's\n    output (with thinking, tool calls, etc.), not how it appears in history (where thinking\n    might be stripped).\n    \"\"\"\n    # Find all assistant message indices\n    assistant_indices = [i for i, m in enumerate(messages) if m[\"role\"] == \"assistant\"]\n\n    if len(assistant_indices) < 2:\n        return  # Need at least 2 assistant messages to test extension\n\n    # Build sequences for comparison\n    # seq[i] = observation before assistant i + completion for assistant i\n    # We check if seq[i] is a prefix of observation before assistant i+1\n    for i in range(len(assistant_indices) - 1):\n        asst_idx = assistant_indices[i]\n        next_asst_idx = assistant_indices[i + 1]\n\n        # Build the assistant's completion - we need to render the assistant message\n        # as it would appear when generated (with thinking preserved), not as it\n        # would appear in history. 
We do this by building a supervised example and\n        # extracting the tokens after the prompt.\n        messages_through_asst = messages[: asst_idx + 1]\n        model_input_through_asst, _ = renderer.build_supervised_example(messages_through_asst)\n        seq_through_asst = model_input_through_asst.to_ints()\n\n        # Build prompt before the next assistant message (observation_{t+1})\n        context_before_next = messages[:next_asst_idx]\n        prompt_before_next = renderer.build_generation_prompt(context_before_next).to_ints()\n\n        # Check if seq_through_asst is a prefix of prompt_before_next\n        is_prefix = prompt_before_next[: len(seq_through_asst)] == seq_through_asst\n        if not is_prefix:\n            # Decode for debugging\n            seq_str = tokenizer.decode(seq_through_asst)\n            next_prompt_str = tokenizer.decode(prompt_before_next)\n            # Find where they diverge\n            diverge_idx = 0\n            for j in range(min(len(seq_through_asst), len(prompt_before_next))):\n                if seq_through_asst[j] != prompt_before_next[j]:\n                    diverge_idx = j\n                    break\n            else:\n                diverge_idx = min(len(seq_through_asst), len(prompt_before_next))\n\n            raise AssertionError(\n                f\"Extension property violated between assistant {i} and {i + 1}.\\n\"\n                f\"Full sequence through asst {i} (len={len(seq_through_asst)}) is NOT a prefix \"\n                f\"of prompt before asst {i + 1} (len={len(prompt_before_next)}).\\n\"\n                f\"Divergence at token {diverge_idx}:\\n\"\n                f\"  Seq through asst[{diverge_idx}:]: {seq_through_asst[diverge_idx : diverge_idx + 10]}\\n\"\n                f\"  Next prompt[{diverge_idx}:]:      {prompt_before_next[diverge_idx : diverge_idx + 10]}\\n\"\n                f\"Sequence through assistant: {seq_str}\\n\"\n                f\"Next prompt: {next_prompt_str}\"\n            )\n\n\n# Test extension property actually holds for renderers that claim it\n# Format: (model_name, renderer_name_or_class, renderer_kwargs, conversation_fn)\n# If renderer_name_or_class is a string, use get_renderer; if a class, instantiate directly\n_EXTENSION_PROPERTY_TEST_PARAMS = [\n    # Llama3 with basic multi-turn (tool calling not supported - see llama3.py docstring)\n    (\"meta-llama/Llama-3.2-1B-Instruct\", \"llama3\", {}, get_basic_4turn_conversation),\n    # RoleColon with basic multi-turn (doesn't support tools)\n    (\"meta-llama/Llama-3.2-1B-Instruct\", \"role_colon\", {}, get_basic_4turn_conversation),\n    # Qwen3 Instruct with basic multi-turn\n    (\"Qwen/Qwen3-8B\", \"qwen3_instruct\", {}, get_basic_4turn_conversation),\n    # Qwen3 Instruct with tool calls\n    (\"Qwen/Qwen3-8B\", \"qwen3_instruct\", {}, get_multiturn_tool_conversation),\n    # Qwen3 with strip_thinking_from_history=False (preserves thinking)\n    (\n        \"Qwen/Qwen3-8B\",\n        Qwen3Renderer,\n        {\"strip_thinking_from_history\": False},\n        get_multiturn_thinking_conversation,\n    ),\n    # Qwen3.5 with strip_thinking_from_history=False (preserves thinking)\n    (\n        \"Qwen/Qwen3.5-35B-A3B\",\n        Qwen3_5Renderer,\n        {\"strip_thinking_from_history\": False},\n        get_multiturn_thinking_conversation,\n    ),\n    # Qwen3.5 disable thinking with strip_thinking_from_history=False (preserves thinking)\n    (\n        \"Qwen/Qwen3.5-35B-A3B\",\n        Qwen3_5DisableThinkingRenderer,\n        
{\"strip_thinking_from_history\": False},\n        get_multiturn_thinking_conversation,\n    ),\n    # DeepSeek non-thinking with basic multi-turn\n    (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3\", {}, get_basic_4turn_conversation),\n    # DeepSeek non-thinking with tool calls\n    (\"deepseek-ai/DeepSeek-V3.1\", \"deepseekv3\", {}, get_multiturn_tool_conversation),\n    # DeepSeek with strip_thinking_from_history=False (preserves thinking)\n    (\n        \"deepseek-ai/DeepSeek-V3.1\",\n        DeepSeekV3ThinkingRenderer,\n        {\"strip_thinking_from_history\": False},\n        get_multiturn_thinking_conversation,\n    ),\n    # DeepSeek with strip_thinking_from_history=False + tool calls\n    (\n        \"deepseek-ai/DeepSeek-V3.1\",\n        DeepSeekV3ThinkingRenderer,\n        {\"strip_thinking_from_history\": False},\n        get_multiturn_thinking_and_tool_conversation,\n    ),\n]\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_name_or_class,renderer_kwargs,conversation_fn\",\n    _EXTENSION_PROPERTY_TEST_PARAMS,\n)\ndef test_extension_property_holds(\n    model_name, renderer_name_or_class, renderer_kwargs, conversation_fn\n):\n    \"\"\"\n    Test that renderers with has_extension_property=True actually satisfy the property.\n    For each conversation, verify that build_generation_prompt at successive assistant\n    turns produces token sequences where each is a prefix of the next.\n    \"\"\"\n    tokenizer = get_tokenizer(model_name)\n\n    if isinstance(renderer_name_or_class, str):\n        renderer = get_renderer(renderer_name_or_class, tokenizer)\n    else:\n        renderer = renderer_name_or_class(tokenizer, **renderer_kwargs)\n\n    assert renderer.has_extension_property, (\n        f\"Expected {renderer_name_or_class} to have has_extension_property=True\"\n    )\n\n    messages = conversation_fn()\n    _verify_extension_property(renderer, messages, tokenizer)\n\n\ndef test_extension_property_breaks_when_expected():\n    \"\"\"\n    Verify that extension property actually breaks for renderers that strip thinking.\n    This confirms our test helper can detect violations.\n    \"\"\"\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-8B\")\n    renderer = Qwen3Renderer(tokenizer, strip_thinking_from_history=True)\n\n    assert not renderer.has_extension_property, \"Default Qwen3Renderer should NOT have extension\"\n\n    messages = get_multiturn_thinking_conversation()\n\n    # Extension should break - expect an assertion error\n    with pytest.raises(AssertionError, match=\"Extension property violated\"):\n        _verify_extension_property(renderer, messages, tokenizer)\n\n\n@pytest.fixture\ndef cleanup_custom_renderer():\n    \"\"\"Fixture to ensure custom renderers are cleaned up after tests.\"\"\"\n    registered_names: list[str] = []\n    yield registered_names\n    # Cleanup: unregister any renderers that were registered during the test\n    for name in registered_names:\n        unregister_renderer(name)\n\n\ndef test_register_and_get_custom_renderer(cleanup_custom_renderer):\n    \"\"\"Test that a custom renderer can be registered and retrieved via get_renderer.\"\"\"\n    custom_name = \"_test_custom_renderer_abc123\"\n    cleanup_custom_renderer.append(custom_name)\n\n    # Should not be registered initially\n    assert not is_renderer_registered(custom_name)\n\n    # Create a simple factory that returns a Qwen3Renderer\n    def custom_factory(tokenizer, image_processor=None):\n        return Qwen3Renderer(tokenizer)\n\n    # Register the custom renderer\n 
   register_renderer(custom_name, custom_factory)\n\n    # Should now be registered\n    assert is_renderer_registered(custom_name)\n    names = get_registered_renderer_names()\n    assert custom_name in names\n\n    # Verify it can be retrieved\n    tokenizer = get_tokenizer(\"Qwen/Qwen3-8B\")\n    renderer = get_renderer(custom_name, tokenizer)\n\n    assert isinstance(renderer, Qwen3Renderer)\n\n    unregister_renderer(custom_name)\n\n    with pytest.raises(ValueError, match=\"Unknown renderer\"):\n        renderer = get_renderer(custom_name, tokenizer)\n\n\n@pytest.fixture\ndef cleanup_custom_tokenizer():\n    \"\"\"Fixture to ensure custom tokenizers are cleaned up after tests.\"\"\"\n    registered_names: list[str] = []\n    yield registered_names\n    # Cleanup: unregister any tokenizers that were registered during the test\n    for name in registered_names:\n        unregister_tokenizer(name)\n\n\ndef test_register_and_get_custom_tokenizer(cleanup_custom_tokenizer):\n    \"\"\"Test that a custom tokenizer can be registered and retrieved via get_tokenizer.\"\"\"\n    custom_name = \"_test_custom_tokenizer_abc123\"\n    cleanup_custom_tokenizer.append(custom_name)\n\n    # Should not be registered initially\n    assert not is_tokenizer_registered(custom_name)\n\n    # Create a simple factory that returns an existing tokenizer\n    real_tokenizer = get_tokenizer(\"Qwen/Qwen3-8B\")\n\n    def custom_factory():\n        return real_tokenizer\n\n    # Register the custom tokenizer\n    register_tokenizer(custom_name, custom_factory)\n\n    # Should now be registered\n    assert is_tokenizer_registered(custom_name)\n    names = get_registered_tokenizer_names()\n    assert custom_name in names\n\n    # Verify it can be retrieved\n    tokenizer = get_tokenizer(custom_name)\n    assert tokenizer is real_tokenizer\n\n    # Unregister and verify it falls back to HF (which will fail for fake name)\n    unregister_tokenizer(custom_name)\n    assert not is_tokenizer_registered(custom_name)\n"
  },
  {
    "path": "tinker_cookbook/renderers/role_colon.py",
    "content": "\"\"\"Simple role:content format renderer.\"\"\"\n\nimport tinker\n\nfrom tinker_cookbook.renderers.base import (\n    Message,\n    RenderContext,\n    RenderedMessage,\n    Renderer,\n    ToolSpec,\n    ensure_text,\n)\n\n\nclass RoleColonRenderer(Renderer):\n    \"\"\"Simple role:content format renderer.\n\n    Format::\n\n        User: <content>\n\n        Assistant: <content>\n\n    This is basically the format used by DeepSeek R1-Zero, and similar to the format\n    used by Anthropic, except that they use \"Human\" instead of \"User\".\n    \"\"\"\n\n    @property\n    def has_extension_property(self) -> bool:\n        \"\"\"RoleColon satisfies the extension property - no content is stripped from history.\"\"\"\n        return True\n\n    def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:\n        header_str = message[\"role\"].capitalize() + \":\"\n        output_str = \" \" + ensure_text(message[\"content\"]) + \"\\n\\n\"\n        # stop_overlap completes the stop sequence \"\\n\\nUser:\" for assistant messages.\n        # For non-assistant messages, we use a placeholder that's never actually concatenated.\n        stop_overlap_str = \"User:\" if message[\"role\"] == \"assistant\" else \"<UNUSED>\"\n        header = tinker.types.EncodedTextChunk(\n            tokens=self.tokenizer.encode(header_str, add_special_tokens=False)\n        )\n        output: list[tinker.ModelInputChunk] = [\n            tinker.types.EncodedTextChunk(\n                tokens=self.tokenizer.encode(output_str, add_special_tokens=False)\n            )\n        ]\n        stop_overlap = tinker.types.EncodedTextChunk(\n            tokens=self.tokenizer.encode(stop_overlap_str, add_special_tokens=False)\n        )\n        return RenderedMessage(header=header, output=output, stop_overlap=stop_overlap)\n\n    def get_stop_sequences(self) -> list[str]:\n        return [\"\\n\\nUser:\"]\n\n    def parse_response(self, response: list[int]) -> tuple[Message, bool]:\n        import logging\n\n        logger = logging.getLogger(__name__)\n\n        # Strip EOS token from the end if present (base models may terminate with EOS\n        # instead of the expected stop sequence). We still return False for parse success\n        # since the model didn't produce the expected stop sequence.\n        terminated_with_eos = False\n        eos_token_id = self.tokenizer.eos_token_id\n        if eos_token_id is not None and response and response[-1] == eos_token_id:\n            response = response[:-1]\n            terminated_with_eos = True\n\n        str_response = str(self.tokenizer.decode(response))\n        splitted = str_response.split(\"\\n\\nUser:\")\n        if len(splitted) == 1:\n            logger.debug(f\"Response is not a valid assistant response: {str_response}\")\n            return Message(role=\"assistant\", content=str_response.strip()), False\n        elif len(splitted) == 2:\n            before, _after = splitted\n            return Message(role=\"assistant\", content=before.strip()), not terminated_with_eos\n        else:\n            logger.warning(\n                \"RoleColonRenderer.parse_response saw multiple stop delimiters \"\n                \"(count=%d). Returning parse_success=False. 
Full response:\\n%s\",\n                len(splitted) - 1,\n                str_response,\n            )\n            return Message(role=\"assistant\", content=splitted[0].strip()), False\n\n    @property\n    def _bos_tokens(self) -> list[int]:\n        bos_token_str = self.tokenizer.bos_token\n        if bos_token_str is None:\n            return []\n        assert isinstance(bos_token_str, str)\n        return self.tokenizer.encode(bos_token_str, add_special_tokens=False)\n\n    def create_conversation_prefix_with_tools(\n        self, tools: list[ToolSpec], system_prompt: str = \"\"\n    ) -> list[Message]:\n        raise NotImplementedError(\"RoleColonRenderer does not support tool calling\")\n"
  },
  {
    "path": "tinker_cookbook/renderers/testing_utils.py",
    "content": "\"\"\"Shared test utilities for renderer tests.\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nimport pytest\nimport transformers\n\n\ndef extract_token_ids(result: Any) -> list[int]:\n    \"\"\"Extract token IDs from apply_chat_template result.\n\n    transformers 4.x returns list[int], while 5.x returns BatchEncoding (dict-like\n    with 'input_ids' and 'attention_mask' keys). This helper normalizes both to list[int].\n    \"\"\"\n    if hasattr(result, \"input_ids\"):\n        return list(result[\"input_ids\"])\n    return list(result)\n\n\n_DEEPSEEK_TOKENIZER_BUG = (\n    \"transformers 5.3.0 has a known bug with DeepSeek tokenizer that strips spaces during decode. \"\n    \"See https://github.com/huggingface/transformers/pull/44801\"\n)\n\n_HAS_DEEPSEEK_TOKENIZER_BUG = transformers.__version__ == \"5.3.0\"\n\nskip_deepseek_tokenizer_bug = pytest.mark.skipif(\n    _HAS_DEEPSEEK_TOKENIZER_BUG,\n    reason=_DEEPSEEK_TOKENIZER_BUG,\n)\n\n\ndef skip_if_deepseek_tokenizer_bug(model_name: str) -> None:\n    \"\"\"Skip the current test if running DeepSeek on transformers 5.3.0.\"\"\"\n    if _HAS_DEEPSEEK_TOKENIZER_BUG and \"deepseek\" in model_name.lower():\n        pytest.skip(_DEEPSEEK_TOKENIZER_BUG)\n"
  },
  {
    "path": "tinker_cookbook/renderers/tool_calling_test.py",
    "content": "\"\"\"\nTests for tool calling support in renderers.\n\nThese tests verify that renderers correctly handle:\n1. Tool response message rendering (role mapping and content wrapping)\n2. Parsing of single and multiple tool calls from model output\n3. Stripping tool call blocks from parsed message content\n4. Extracting function names from model-specific tool call formats\n\"\"\"\n\nimport pytest\nimport tinker\n\nfrom tinker_cookbook.renderers import Message, RenderContext, get_renderer, get_text_content\nfrom tinker_cookbook.renderers.testing_utils import skip_deepseek_tokenizer_bug\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n# =============================================================================\n# Tool Response Rendering Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_name\",\n    [\n        (\"Qwen/Qwen3-8B\", \"qwen3\"),\n        (\"Qwen/Qwen3-30B-A3B-Instruct-2507\", \"qwen3_instruct\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5_disable_thinking\"),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3\"),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3_disable_thinking\"),\n    ],\n)\ndef test_qwen3_tool_response_rendering(model_name: str, renderer_name: str):\n    \"\"\"Test that Qwen3 renders tool responses with user role and tool_response tags.\n\n    Per the Qwen3 chat template, tool messages should render as\n    <|im_start|>user with content wrapped in <tool_response> tags.\n    \"\"\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    tool_message: Message = {\"role\": \"tool\", \"content\": '{\"weather\": \"sunny\", \"high\": 72}'}\n\n    ctx = RenderContext(idx=0, is_last=False, prev_message=None)\n    rendered = renderer.render_message(tool_message, ctx)\n    header = rendered.header\n    assert header is not None, \"Expected header in rendered message\"\n    output = rendered.output\n    assert len(output) > 0, \"Expected output in rendered message\"\n\n    header_str = tokenizer.decode(list(header.tokens))\n    # output[0] is an EncodedTextChunk for text-only messages\n    output_chunk = output[0]\n    assert isinstance(output_chunk, tinker.EncodedTextChunk), \"Expected EncodedTextChunk\"\n    output_str = tokenizer.decode(list(output_chunk.tokens))\n\n    # Tool messages should be rendered as \"user\" role\n    assert \"<|im_start|>user\" in header_str\n    # Content should be wrapped in tool_response tags\n    assert \"<tool_response>\" in output_str\n    assert \"</tool_response>\" in output_str\n    assert '\"weather\": \"sunny\"' in output_str\n\n\n# =============================================================================\n# Tool Call Parsing Tests\n# =============================================================================\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_name\",\n    [\n        (\"Qwen/Qwen3-8B\", \"qwen3\"),\n        (\"Qwen/Qwen3-30B-A3B-Instruct-2507\", \"qwen3_instruct\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5_disable_thinking\"),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3\"),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3_disable_thinking\"),\n    ],\n)\ndef test_qwen3_parse_single_tool_call(model_name: str, renderer_name: str):\n    
\"\"\"Test parsing a single tool call from Qwen3 response.\"\"\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    # Simulate model response with tool call\n    response_text = \"\"\"I'll search for that information.\n<tool_call>\n{\"name\": \"search\", \"arguments\": {\"query\": \"weather in NYC\"}}\n</tool_call><|im_end|>\"\"\"\n    if renderer_name.startswith(\"qwen3_5\") or renderer_name.startswith(\"nemotron3\"):\n        response_text = \"\"\"I'll search for that information.\n<tool_call>\n<function=search>\n<parameter=query>\nweather in NYC\n</parameter>\n</function>\n</tool_call><|im_end|>\"\"\"\n\n    response_tokens = tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success is True\n    assert message[\"role\"] == \"assistant\"\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 1\n    assert message[\"tool_calls\"][0].function.name == \"search\"\n    # Content should have tool_call block stripped (text content only)\n    text_content = get_text_content(message)\n    assert \"<tool_call>\" not in text_content\n\n\n@pytest.mark.parametrize(\n    \"model_name,renderer_name\",\n    [\n        (\"Qwen/Qwen3-8B\", \"qwen3\"),\n        (\"Qwen/Qwen3.5-35B-A3B\", \"qwen3_5\"),\n        (\"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\", \"nemotron3\"),\n    ],\n)\ndef test_qwen3_parse_multiple_tool_calls(model_name: str, renderer_name: str):\n    \"\"\"Test parsing multiple tool calls from Qwen3 response.\n\n    When a model response contains multiple <tool_call> blocks, all should be parsed.\n    \"\"\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    # Simulate model response with multiple tool calls\n    response_text = \"\"\"I'll get the weather for both cities.\n<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"NYC\"}}\n</tool_call>\n<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"LA\"}}\n</tool_call><|im_end|>\"\"\"\n    if renderer_name in (\"qwen3_5\", \"nemotron3\"):\n        response_text = \"\"\"I'll get the weather for both cities.\n<tool_call>\n<function=get_weather>\n<parameter=location>\nNYC\n</parameter>\n</function>\n</tool_call>\n<tool_call>\n<function=get_weather>\n<parameter=location>\nLA\n</parameter>\n</function>\n</tool_call><|im_end|>\"\"\"\n\n    response_tokens = tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success is True\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 2\n    assert message[\"tool_calls\"][0].function.name == \"get_weather\"\n    assert message[\"tool_calls\"][1].function.name == \"get_weather\"\n    # Verify different arguments\n    assert \"NYC\" in message[\"tool_calls\"][0].function.arguments\n    assert \"LA\" in message[\"tool_calls\"][1].function.arguments\n\n\ndef test_kimi_k2_parse_tool_call():\n    \"\"\"Test parsing tool call from Kimi K2 response.\n\n    Kimi K2 uses tool_id format \"functions.{name}:{idx}\", and the function\n    name should be extracted correctly.\n    \"\"\"\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    # Simulate model response with tool call (Kimi K2 format)\n    response_text = \"\"\"<think></think>I'll search 
for that.\n<|tool_calls_section_begin|><|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{\"query\": \"weather NYC\"}<|tool_call_end|><|tool_calls_section_end|><|im_end|>\"\"\"\n\n    response_tokens = tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success is True\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 1\n    # Verify function name is extracted from tool_id\n    assert message[\"tool_calls\"][0].function.name == \"search\"\n    assert message[\"tool_calls\"][0].id == \"functions.search:0\"\n\n\n@skip_deepseek_tokenizer_bug\ndef test_deepseek_parse_tool_call():\n    \"\"\"Test parsing tool call from DeepSeek V3 response.\n\n    DeepSeek V3 HF template format: <｜tool▁call▁begin｜>name<｜tool▁sep｜>args<｜tool▁call▁end｜>\n    \"\"\"\n    model_name = \"deepseek-ai/DeepSeek-V3.1\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"deepseekv3\", tokenizer)\n\n    response_text = \"\"\"I'll check the weather.\n<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>get_weather<｜tool▁sep｜>{\"location\": \"NYC\"}<｜tool▁call▁end｜><｜tool▁calls▁end｜><｜end▁of▁sentence｜>\"\"\"\n\n    response_tokens = tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success is True\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 1\n    assert message[\"tool_calls\"][0].function.name == \"get_weather\"\n    assert \"NYC\" in message[\"tool_calls\"][0].function.arguments\n    # Content should have tool calls section stripped (text content only)\n    text_content = get_text_content(message)\n    assert \"<｜tool▁calls▁begin｜>\" not in text_content\n\n\n# =============================================================================\n# Edge Cases and Error Handling\n# =============================================================================\n\n\ndef test_qwen3_parse_invalid_tool_call_json():\n    \"\"\"Test that invalid JSON in tool call is captured as unparsed_tool_calls.\"\"\"\n    model_name = \"Qwen/Qwen3-8B\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"qwen3\", tokenizer)\n\n    # Invalid JSON in tool call\n    response_text = \"\"\"<tool_call>\n{invalid json here}\n</tool_call><|im_end|>\"\"\"\n\n    response_tokens = tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = renderer.parse_response(response_tokens)\n\n    # Parse succeeds, but tool call is captured as unparsed\n    assert success is True\n    assert \"tool_calls\" not in message or len(message.get(\"tool_calls\", [])) == 0\n    assert \"unparsed_tool_calls\" in message\n    assert len(message[\"unparsed_tool_calls\"]) == 1\n    assert \"Invalid JSON\" in message[\"unparsed_tool_calls\"][0].error\n    # Raw text should contain the original tool call\n    assert \"<tool_call>\" in message[\"unparsed_tool_calls\"][0].raw_text\n\n\ndef test_qwen3_mixed_valid_invalid_tool_calls():\n    \"\"\"Test parsing when some tool calls are valid and some are invalid.\n\n    Valid tool calls should be parsed, invalid ones captured in unparsed_tool_calls.\n    \"\"\"\n    model_name = \"Qwen/Qwen3-8B\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"qwen3\", tokenizer)\n\n    # First tool call is valid, second has invalid JSON\n    response_text = \"\"\"I'll try both.\n<tool_call>\n{\"name\": \"search\", \"arguments\": 
{\"query\": \"weather\"}}\n</tool_call>\n<tool_call>\n{bad json here}\n</tool_call><|im_end|>\"\"\"\n\n    response_tokens = tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success is True\n    # Valid tool call should be parsed\n    assert \"tool_calls\" in message\n    assert len(message[\"tool_calls\"]) == 1\n    assert message[\"tool_calls\"][0].function.name == \"search\"\n    # Invalid tool call should be in unparsed_tool_calls\n    assert \"unparsed_tool_calls\" in message\n    assert len(message[\"unparsed_tool_calls\"]) == 1\n    assert \"Invalid JSON\" in message[\"unparsed_tool_calls\"][0].error\n\n\n@skip_deepseek_tokenizer_bug\ndef test_deepseek_parse_invalid_tool_call_json():\n    \"\"\"Test that invalid JSON in DeepSeek tool call is captured as unparsed.\"\"\"\n    model_name = \"deepseek-ai/DeepSeek-V3.1\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"deepseekv3\", tokenizer)\n\n    response_text = \"\"\"I'll check.\n<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>get_weather<｜tool▁sep｜>{invalid json}<｜tool▁call▁end｜><｜tool▁calls▁end｜><｜end▁of▁sentence｜>\"\"\"\n\n    response_tokens = tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success is True\n    assert \"tool_calls\" not in message or len(message.get(\"tool_calls\", [])) == 0\n    assert \"unparsed_tool_calls\" in message\n    assert len(message[\"unparsed_tool_calls\"]) == 1\n    assert \"Invalid JSON\" in message[\"unparsed_tool_calls\"][0].error\n\n\ndef test_kimi_k2_parse_invalid_tool_call_json():\n    \"\"\"Test that invalid JSON in Kimi K2 tool call is captured as unparsed.\"\"\"\n    model_name = \"moonshotai/Kimi-K2-Thinking\"\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(\"kimi_k2\", tokenizer)\n\n    response_text = \"\"\"<think></think>I'll search.\n<|tool_calls_section_begin|><|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{invalid}<|tool_call_end|><|tool_calls_section_end|><|im_end|>\"\"\"\n\n    response_tokens = tokenizer.encode(response_text, add_special_tokens=False)\n    message, success = renderer.parse_response(response_tokens)\n\n    assert success is True\n    assert \"tool_calls\" not in message or len(message.get(\"tool_calls\", [])) == 0\n    assert \"unparsed_tool_calls\" in message\n    assert len(message[\"unparsed_tool_calls\"]) == 1\n    assert \"Invalid JSON\" in message[\"unparsed_tool_calls\"][0].error\n"
  },
  {
    "path": "tinker_cookbook/rl/__init__.py",
    "content": ""
  },
  {
    "path": "tinker_cookbook/rl/builder_pickle_test.py",
    "content": "\"\"\"Tests for picklability of RL EnvGroupBuilders and rollout executor infrastructure.\"\"\"\n\nimport pickle\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import partial\n\nfrom tinker_cookbook.renderers import Message, get_renderer\nfrom tinker_cookbook.rl.problem_env import ProblemGroupBuilder\nfrom tinker_cookbook.rl.rollouts import (\n    _RolloutTask,\n    get_rollout_executor,\n    set_rollout_executor,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n\nclass TestProblemGroupBuilderPickle:\n    def test_pickle_roundtrip(self) -> None:\n        \"\"\"ProblemGroupBuilder with a Renderer-bound env_thunk survives pickle.\n\n        Uses the real MathEnv class, matching how recipes actually construct builders.\n        \"\"\"\n        from tinker_cookbook.recipes.math_rl.math_env import MathEnv\n\n        tokenizer = get_tokenizer(\"meta-llama/Llama-3.1-8B-Instruct\")\n        renderer = get_renderer(\"llama3\", tokenizer)\n\n        builder = ProblemGroupBuilder(\n            env_thunk=partial(MathEnv, \"What is 2+2?\", \"4\", renderer),\n            num_envs=4,\n            dataset_name=\"test_math\",\n        )\n\n        restored = pickle.loads(pickle.dumps(builder))\n\n        assert restored.num_envs == 4\n        assert restored.dataset_name == \"test_math\"\n        # Verify the renderer inside the partial survived\n        assert restored.env_thunk.args[2]._renderer_name == \"llama3\"\n\n    def test_pickle_with_convo_prefix(self) -> None:\n        \"\"\"ProblemGroupBuilder with convo_prefix in the partial survives pickle.\"\"\"\n        from tinker_cookbook.recipes.math_rl.math_env import MathEnv\n\n        tokenizer = get_tokenizer(\"meta-llama/Llama-3.1-8B-Instruct\")\n        renderer = get_renderer(\"llama3\", tokenizer)\n        convo_prefix: list[Message] = [{\"role\": \"system\", \"content\": \"You are helpful.\"}]\n\n        builder = ProblemGroupBuilder(\n            env_thunk=partial(MathEnv, \"What is 2+2?\", \"4\", renderer, convo_prefix=convo_prefix),\n            num_envs=2,\n        )\n\n        restored = pickle.loads(pickle.dumps(builder))\n        assert restored.env_thunk.keywords[\"convo_prefix\"] == convo_prefix\n\n\nclass TestRolloutTask:\n    def test_pickle_roundtrip(self) -> None:\n        \"\"\"_RolloutTask survives pickle roundtrip with a real Renderer-bound builder.\"\"\"\n        from tinker_cookbook.recipes.math_rl.math_env import MathEnv\n\n        tokenizer = get_tokenizer(\"meta-llama/Llama-3.1-8B-Instruct\")\n        renderer = get_renderer(\"llama3\", tokenizer)\n\n        builder = ProblemGroupBuilder(\n            env_thunk=partial(MathEnv, \"What is 2+2?\", \"4\", renderer),\n            num_envs=2,\n        )\n\n        # SamplingClient can't be constructed without a server, so test with None\n        # to verify the dataclass + builder pickle. 
Full integration requires a server.\n        task = _RolloutTask(\n            sampling_client=None,  # type: ignore[arg-type]\n            env_group_builder=builder,\n            max_tokens=256,\n            temperature=1.0,\n            remove_constant_reward_groups=False,\n            enable_logging=False,\n        )\n\n        restored = pickle.loads(pickle.dumps(task))\n        assert restored.max_tokens == 256\n        assert restored.temperature == 1.0\n        assert restored.remove_constant_reward_groups is False\n        assert restored.env_group_builder.num_envs == 2\n        assert restored.env_group_builder.env_thunk.args[2]._renderer_name == \"llama3\"\n\n\nclass TestRolloutExecutorContextVar:\n    def test_default_is_none(self) -> None:\n        \"\"\"Default rollout executor is None (in-process async).\"\"\"\n        assert get_rollout_executor() is None\n\n    def test_set_and_get(self) -> None:\n        \"\"\"set_rollout_executor / get_rollout_executor roundtrip.\"\"\"\n        executor = ThreadPoolExecutor(max_workers=1)\n        try:\n            set_rollout_executor(executor)\n            assert get_rollout_executor() is executor\n        finally:\n            set_rollout_executor(None)\n            executor.shutdown(wait=False)\n        assert get_rollout_executor() is None\n"
  },
  {
    "path": "tinker_cookbook/rl/data_processing.py",
    "content": "\"\"\"\nData processing functions for RL training.\n\nContains functions for computing advantages, converting trajectories to training data,\nand assembling training batches.\n\"\"\"\n\nimport logging\n\nimport tinker\nimport torch\nfrom tinker import TensorData\n\nfrom tinker_cookbook.rl.types import Trajectory, TrajectoryGroup\nfrom tinker_cookbook.supervised.common import (\n    create_rightshifted_model_input_and_leftshifted_targets,\n)\nfrom tinker_cookbook.utils.misc_utils import all_same, safezip\n\nlogger = logging.getLogger(__name__)\n\n\ndef compute_advantages(trajectory_groups_P: list[TrajectoryGroup]) -> list[torch.Tensor]:\n    \"\"\"Compute advantages for each trajectory, centered within groups.\"\"\"\n    advantages_P: list[torch.Tensor] = []\n\n    for traj_group in trajectory_groups_P:\n        rewards_G = torch.tensor(traj_group.get_total_rewards())\n        # Center advantages within the group\n        advantages_G = rewards_G - rewards_G.mean()\n        advantages_P.append(advantages_G)\n\n    return advantages_P\n\n\nFlatObElem = int | tinker.ModelInputChunk\nFlatOb = list[FlatObElem]\n\n\ndef _is_prefix(seq1: FlatOb, seq2: FlatOb) -> bool:\n    \"\"\"\n    Check if seq1 is a prefix of seq2.\n    \"\"\"\n    return len(seq1) <= len(seq2) and seq2[: len(seq1)] == seq1\n\n\ndef _flat_ob_token_len(flat_ob: FlatOb) -> int:\n    out = 0\n    for elem in flat_ob:\n        if isinstance(elem, int):\n            out += 1\n        else:\n            out += elem.length\n    return out\n\n\ndef _flat_ob_to_model_input(flat_ob: FlatOb) -> tinker.ModelInput:\n    out: list[tinker.ModelInputChunk] = []\n    current_text_chunk: list[int] = []\n\n    def flush_text_chunk():\n        if current_text_chunk:\n            out.append(tinker.EncodedTextChunk(tokens=current_text_chunk))\n            current_text_chunk.clear()\n\n    for elem in flat_ob:\n        if isinstance(elem, int):\n            current_text_chunk.append(elem)\n        else:\n            flush_text_chunk()\n            out.append(elem)\n    flush_text_chunk()\n    return tinker.ModelInput(chunks=out)\n\n\ndef _flatten_chunks(chunks: list[tinker.ModelInputChunk]) -> FlatOb:\n    out: FlatOb = []\n    for chunk in chunks:\n        if isinstance(chunk, tinker.EncodedTextChunk):\n            out.extend(chunk.tokens)\n        else:\n            out.append(chunk)\n    return out\n\n\ndef trajectory_to_data(traj: Trajectory, traj_advantage: float) -> list[tinker.Datum]:\n    \"\"\"\n    Return one or more Datum objects corresponding to the trajectory.\n    If the sequence grows by appending, i.e., each successive observation contains\n    the previous observation+action as a prefix, then we can return a single Datum.\n    However, if we get a sequence that's not an extension of the previous sequence,\n    then that results in a new Datum.\n\n    For example, let O1 denote a chunk of observation tokens, and let A1 denote an action.\n\n    Then let's say ob_ac_pairs is as follows.\n\n    (O1, A1)\n    (O1+A1+O2, A2)\n    (O3, A3)\n\n    Then we will merge the first two observation-action pairs into a single Datum,\n    and the last observation-action pair into a separate Datum.\n    \"\"\"\n\n    class SequenceAccumulator:\n        full_sequence: list[FlatObElem] = []\n        sampled_logprobs: list[float] = []\n        advantages: list[float] = []\n        mask: list[float] = []\n\n        @classmethod\n        def clear(cls):\n            cls.full_sequence = []\n            cls.sampled_logprobs = []\n           
 cls.advantages = []\n            cls.mask = []\n\n    def make_datum_from_state():\n        all_tokens_T = _flat_ob_to_model_input(SequenceAccumulator.full_sequence)\n        input_tokens_T, target_tokens_T = create_rightshifted_model_input_and_leftshifted_targets(\n            list(all_tokens_T.chunks)\n        )\n        sampled_logprobs_T = SequenceAccumulator.sampled_logprobs[1:]\n        advantages_T = SequenceAccumulator.advantages[1:]\n        mask_T = SequenceAccumulator.mask[1:]\n        assert (\n            input_tokens_T.length\n            == len(target_tokens_T)\n            == len(sampled_logprobs_T)\n            == len(advantages_T)\n            == len(mask_T)\n        )\n        return tinker.Datum(\n            model_input=input_tokens_T,\n            loss_fn_inputs={\n                \"target_tokens\": TensorData.from_torch(torch.tensor(target_tokens_T)),\n                \"logprobs\": TensorData.from_torch(torch.tensor(sampled_logprobs_T)),\n                \"advantages\": TensorData.from_torch(torch.tensor(advantages_T)),\n                \"mask\": TensorData.from_torch(torch.tensor(mask_T)),\n            },\n        )\n\n    data: list[tinker.Datum] = []\n    for transition in traj.transitions:\n        ob = transition.ob\n        ob_flat = _flatten_chunks(ob.chunks)\n        ac_with_logprobs = transition.ac\n        if len(SequenceAccumulator.full_sequence) == 0:\n            delta_ob_flat = ob_flat\n        elif _is_prefix(SequenceAccumulator.full_sequence, ob_flat):\n            delta_ob_flat = ob_flat[len(SequenceAccumulator.full_sequence) :]\n        else:\n            data.append(make_datum_from_state())\n            SequenceAccumulator.clear()\n            delta_ob_flat = ob_flat\n        delta_ob_len = _flat_ob_token_len(delta_ob_flat)\n        SequenceAccumulator.full_sequence.extend(delta_ob_flat)\n        SequenceAccumulator.full_sequence.extend(ac_with_logprobs.tokens)\n        SequenceAccumulator.sampled_logprobs.extend(\n            [0.0] * delta_ob_len + ac_with_logprobs.logprobs\n        )\n        SequenceAccumulator.advantages.extend(\n            [0] * delta_ob_len + [traj_advantage] * len(ac_with_logprobs.tokens)\n        )\n        SequenceAccumulator.mask.extend([0.0] * delta_ob_len + [1.0] * len(ac_with_logprobs.tokens))\n\n    if SequenceAccumulator.full_sequence:\n        data.append(make_datum_from_state())\n\n    return data\n\n\ndef assemble_training_data(\n    trajectory_groups_P: list[TrajectoryGroup],\n    advantages_P: list[torch.Tensor],\n) -> tuple[list[tinker.Datum], list[dict[str, int]]]:\n    \"\"\"Convert trajectories to training data format.\"\"\"\n    data_D: list[tinker.Datum] = []\n    metadata_D: list[dict[str, int]] = []\n\n    for i_group, (traj_group, advantages_G) in enumerate(\n        safezip(trajectory_groups_P, advantages_P)\n    ):\n        for i_traj, (traj, traj_advantage) in enumerate(\n            safezip(traj_group.trajectories_G, advantages_G)\n        ):\n            # Build the full sequence from the trajectory\n            new_data = trajectory_to_data(traj, float(traj_advantage))\n            data_D.extend(new_data)\n            metadata_D.extend([{\"group_idx\": i_group, \"traj_idx\": i_traj} for _ in new_data])\n\n    return data_D, metadata_D\n\n\ndef remove_constant_reward_groups(\n    trajectory_groups_P: list[TrajectoryGroup],\n) -> list[TrajectoryGroup]:\n    new_groups: list[TrajectoryGroup] = []\n    for group in trajectory_groups_P:\n        if not all_same(group.get_total_rewards()):\n           
 new_groups.append(group)\n    if not new_groups:\n        logger.warning(\"All rewards are uniform. There will be no gradient\")\n        # return a singleton list, since an empty list would cause problems\n        return trajectory_groups_P[0:1]\n    return new_groups\n"
  },
  {
    "path": "tinker_cookbook/rl/message_env.py",
    "content": "\"\"\"Message-level environment abstraction.\n\nMessageEnv operates at the message level (list[Message]) rather than token level.\n\nEnvFromMessageEnv bridges MessageEnv to the token-level Env interface used by\nthe RL training loop.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, field\n\nimport tinker\n\nfrom tinker_cookbook.completers import StopCondition\nfrom tinker_cookbook.renderers import Renderer\nfrom tinker_cookbook.renderers.base import Message\nfrom tinker_cookbook.rl import types\n\n\n@dataclass\nclass MessageStepResult:\n    \"\"\"Result of a message-level step.\"\"\"\n\n    reward: float\n    episode_done: bool\n    next_messages: list[Message]\n    metrics: dict[str, float] = field(default_factory=dict)\n    logs: types.Logs = field(default_factory=dict)\n    next_stop_condition: StopCondition | None = None\n\n\nclass MessageEnv(ABC):\n    \"\"\"Abstract base class for message-level environments.\"\"\"\n\n    @abstractmethod\n    async def initial_observation(self) -> list[Message]:\n        \"\"\"Return the initial conversation history as renderer messages.\"\"\"\n        ...\n\n    @abstractmethod\n    async def step(self, message: Message) -> MessageStepResult:\n        \"\"\"Process an assistant message and return reward/next state.\"\"\"\n        ...\n\n\nclass EnvFromMessageEnv(types.Env):\n    \"\"\"Adapter that wraps a MessageEnv to implement the token-level Env interface.\n\n    This bridges the message-level abstraction to the token-level interface\n    expected by the RL training loop.\n    \"\"\"\n\n    def __init__(\n        self,\n        renderer: Renderer,\n        message_env: MessageEnv,\n        failed_parse_reward: float = -1.0,\n        terminate_on_parse_error: bool = True,\n        max_trajectory_tokens: int | None = None,\n    ):\n        self.renderer = renderer\n        self.message_env = message_env\n        self.failed_parse_reward = failed_parse_reward\n        self.terminate_on_parse_error = terminate_on_parse_error\n        self.max_trajectory_tokens = max_trajectory_tokens\n        self._base_stop_condition = renderer.get_stop_sequences()\n\n    async def _render_in_thread(self, messages: list[Message], **kwargs) -> tinker.ModelInput:\n        \"\"\"Run build_generation_prompt in a thread to avoid blocking the event loop.\n\n        Tokenization is CPU-bound. With many concurrent tasks on the same event\n        loop, running it synchronously starves other coroutines. 
HuggingFace\n        tokenizers release the GIL, so threads give true parallelism.\n        \"\"\"\n        return await asyncio.to_thread(self.renderer.build_generation_prompt, messages, **kwargs)\n\n    async def initial_observation(self) -> tuple[tinker.ModelInput, StopCondition]:\n        messages = await self.message_env.initial_observation()\n        model_input = await self._render_in_thread(messages)\n        return model_input, self._base_stop_condition\n\n    async def step(self, action: types.Action) -> types.StepResult:\n        \"\"\"Parse tokens to a message, delegate to MessageEnv, and render response.\"\"\"\n        assistant_message, parse_success = self.renderer.parse_response(action)\n\n        if not parse_success:\n            return types.StepResult(\n                reward=self.failed_parse_reward,\n                episode_done=self.terminate_on_parse_error,\n                next_observation=tinker.ModelInput.empty(),\n                next_stop_condition=self._base_stop_condition,\n                metrics={\"parse_error\": 1.0},\n            )\n\n        msg_step = await self.message_env.step(assistant_message)\n        next_observation = await self._render_in_thread(msg_step.next_messages)\n        next_stop_condition = msg_step.next_stop_condition or self._base_stop_condition\n\n        # Check if trajectory exceeds max token limit\n        if (\n            self.max_trajectory_tokens is not None\n            and next_observation.length > self.max_trajectory_tokens\n        ):\n            return types.StepResult(\n                reward=0.0,\n                episode_done=True,\n                next_observation=tinker.ModelInput.empty(),\n                next_stop_condition=self._base_stop_condition,\n                metrics={**msg_step.metrics, \"context_overflow\": 1.0},\n                logs=msg_step.logs,\n            )\n\n        return types.StepResult(\n            reward=msg_step.reward,\n            episode_done=msg_step.episode_done,\n            next_observation=next_observation,\n            next_stop_condition=next_stop_condition,\n            metrics=msg_step.metrics,\n            logs=msg_step.logs,\n        )\n"
  },
  {
    "path": "tinker_cookbook/rl/message_env_test.py",
    "content": "\"\"\"Tests for EnvFromMessageEnv (tinker_cookbook/rl/message_env.py).\n\nVerifies that EnvFromMessageEnv correctly bridges message-level environments\nto the token-level Env interface, including:\n- Threading: build_generation_prompt runs via asyncio.to_thread\n- Parse success/failure handling\n- Max trajectory token enforcement\n- Stop condition propagation\n\"\"\"\n\nimport asyncio\nfrom unittest.mock import MagicMock, patch\n\nimport tinker\n\nfrom tinker_cookbook.renderers.base import Message\nfrom tinker_cookbook.rl.message_env import EnvFromMessageEnv, MessageEnv, MessageStepResult\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _make_model_input(tokens: list[int]) -> tinker.ModelInput:\n    return tinker.ModelInput.from_ints(tokens)\n\n\nclass StubMessageEnv(MessageEnv):\n    \"\"\"Minimal MessageEnv for testing.\"\"\"\n\n    def __init__(\n        self,\n        initial_messages: list[Message],\n        step_result: MessageStepResult,\n    ):\n        self._initial_messages = initial_messages\n        self._step_result = step_result\n        self.step_calls: list[Message] = []\n\n    async def initial_observation(self) -> list[Message]:\n        return self._initial_messages\n\n    async def step(self, message: Message) -> MessageStepResult:\n        self.step_calls.append(message)\n        return self._step_result\n\n\ndef _make_renderer(\n    gen_prompt_tokens: list[int] | None = None,\n    stop_sequences: list[str] | None = None,\n    parse_message: Message | None = None,\n    parse_success: bool = True,\n) -> MagicMock:\n    \"\"\"Build a mock Renderer with the methods EnvFromMessageEnv calls.\"\"\"\n    renderer = MagicMock()\n\n    prompt = _make_model_input(gen_prompt_tokens or [1, 2, 3])\n    renderer.build_generation_prompt = MagicMock(return_value=prompt)\n    renderer.get_stop_sequences = MagicMock(return_value=stop_sequences or [\"<stop>\"])\n    renderer.parse_response = MagicMock(\n        return_value=(\n            parse_message or {\"role\": \"assistant\", \"content\": \"hello\"},\n            parse_success,\n        )\n    )\n    return renderer\n\n\n# ---------------------------------------------------------------------------\n# Tests\n# ---------------------------------------------------------------------------\n\n\nclass TestInitialObservation:\n    def test_returns_rendered_prompt_and_stop_condition(self):\n        \"\"\"initial_observation should render messages and return base stop condition.\"\"\"\n        renderer = _make_renderer(gen_prompt_tokens=[10, 20, 30], stop_sequences=[\"<eos>\"])\n        initial_msgs: list[Message] = [{\"role\": \"user\", \"content\": \"hi\"}]\n        msg_env = StubMessageEnv(\n            initial_messages=initial_msgs,\n            step_result=MessageStepResult(reward=0, episode_done=False, next_messages=[]),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        model_input, stop_cond = asyncio.run(env.initial_observation())\n\n        assert model_input.to_ints() == [10, 20, 30]\n        assert stop_cond == [\"<eos>\"]\n        renderer.build_generation_prompt.assert_called_once_with(initial_msgs)\n\n    def test_render_runs_in_thread(self):\n        \"\"\"build_generation_prompt should be dispatched via asyncio.to_thread.\"\"\"\n        renderer = _make_renderer()\n        msg_env = StubMessageEnv(\n            
initial_messages=[{\"role\": \"user\", \"content\": \"hi\"}],\n            step_result=MessageStepResult(reward=0, episode_done=False, next_messages=[]),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        with patch(\n            \"tinker_cookbook.rl.message_env.asyncio.to_thread\", wraps=asyncio.to_thread\n        ) as mock_to_thread:\n            asyncio.run(env.initial_observation())\n            mock_to_thread.assert_called_once()\n            # First positional arg should be the renderer method\n            assert mock_to_thread.call_args[0][0] is renderer.build_generation_prompt\n\n\nclass TestStepParseFailure:\n    def test_parse_failure_returns_failed_reward(self):\n        \"\"\"When parse_response fails, step returns failed_parse_reward.\"\"\"\n        renderer = _make_renderer(parse_success=False)\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(reward=1.0, episode_done=False, next_messages=[]),\n        )\n        env = EnvFromMessageEnv(\n            renderer=renderer,\n            message_env=msg_env,\n            failed_parse_reward=-2.0,\n            terminate_on_parse_error=True,\n        )\n\n        result = asyncio.run(env.step([1, 2, 3]))\n\n        assert result.reward == -2.0\n        assert result.episode_done is True\n        assert result.metrics == {\"parse_error\": 1.0}\n        assert result.next_observation.length == 0\n        # MessageEnv.step should NOT have been called\n        assert len(msg_env.step_calls) == 0\n\n    def test_parse_failure_no_terminate(self):\n        \"\"\"When terminate_on_parse_error=False, episode continues after parse failure.\"\"\"\n        renderer = _make_renderer(parse_success=False)\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(reward=1.0, episode_done=False, next_messages=[]),\n        )\n        env = EnvFromMessageEnv(\n            renderer=renderer,\n            message_env=msg_env,\n            failed_parse_reward=-1.0,\n            terminate_on_parse_error=False,\n        )\n\n        result = asyncio.run(env.step([1, 2, 3]))\n\n        assert result.episode_done is False\n        assert result.reward == -1.0\n\n\nclass TestStepSuccess:\n    def test_delegates_to_message_env_and_renders(self):\n        \"\"\"On successful parse, step delegates to MessageEnv and renders next messages.\"\"\"\n        assistant_msg: Message = {\"role\": \"assistant\", \"content\": \"answer\"}\n        next_msgs: list[Message] = [\n            {\"role\": \"user\", \"content\": \"hi\"},\n            {\"role\": \"assistant\", \"content\": \"answer\"},\n            {\"role\": \"user\", \"content\": \"followup\"},\n        ]\n        renderer = _make_renderer(\n            gen_prompt_tokens=[10, 20, 30, 40],\n            stop_sequences=[\"<stop>\"],\n            parse_message=assistant_msg,\n        )\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.75,\n                episode_done=False,\n                next_messages=next_msgs,\n                metrics={\"custom\": 1.0},\n            ),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        result = asyncio.run(env.step([5, 6, 7]))\n\n        # Should have delegated parsed message to MessageEnv\n        assert len(msg_env.step_calls) == 1\n        assert msg_env.step_calls[0] == 
assistant_msg\n\n        assert result.reward == 0.75\n        assert result.episode_done is False\n        assert result.next_observation.to_ints() == [10, 20, 30, 40]\n        assert result.metrics == {\"custom\": 1.0}\n        assert result.next_stop_condition == [\"<stop>\"]\n\n    def test_custom_stop_condition_from_message_env(self):\n        \"\"\"When MessageEnv returns a next_stop_condition, it overrides the base one.\"\"\"\n        renderer = _make_renderer(stop_sequences=[\"<base_stop>\"])\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.5,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n                next_stop_condition=[\"<custom_stop>\"],\n            ),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        result = asyncio.run(env.step([1]))\n\n        assert result.next_stop_condition == [\"<custom_stop>\"]\n\n    def test_none_stop_condition_falls_back_to_base(self):\n        \"\"\"When MessageEnv returns None for next_stop_condition, base is used.\"\"\"\n        renderer = _make_renderer(stop_sequences=[\"<base>\"])\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.5,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n                next_stop_condition=None,\n            ),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        result = asyncio.run(env.step([1]))\n\n        assert result.next_stop_condition == [\"<base>\"]\n\n\nclass TestMaxTrajectoryTokens:\n    def test_context_overflow_terminates_episode(self):\n        \"\"\"When next_observation exceeds max_trajectory_tokens, episode ends.\"\"\"\n        # Renderer returns a 100-token observation\n        renderer = _make_renderer(gen_prompt_tokens=list(range(100)), stop_sequences=[\"<s>\"])\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.9,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n                metrics={\"turns\": 5.0},\n            ),\n        )\n        env = EnvFromMessageEnv(\n            renderer=renderer,\n            message_env=msg_env,\n            max_trajectory_tokens=50,  # limit is 50, observation is 100\n        )\n\n        result = asyncio.run(env.step([1]))\n\n        assert result.episode_done is True\n        assert result.reward == 0.0\n        assert result.next_observation.length == 0  # empty observation\n        assert result.metrics[\"context_overflow\"] == 1.0\n        # Original metrics should be preserved\n        assert result.metrics[\"turns\"] == 5.0\n\n    def test_within_limit_continues(self):\n        \"\"\"When next_observation is within max_trajectory_tokens, episode continues.\"\"\"\n        renderer = _make_renderer(gen_prompt_tokens=[1, 2, 3], stop_sequences=[\"<s>\"])\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.5,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n            ),\n        )\n        env = EnvFromMessageEnv(\n            renderer=renderer,\n            
message_env=msg_env,\n            max_trajectory_tokens=1000,  # plenty of room\n        )\n\n        result = asyncio.run(env.step([1]))\n\n        assert result.episode_done is False\n        assert result.reward == 0.5\n        assert \"context_overflow\" not in result.metrics\n\n    def test_no_limit_set(self):\n        \"\"\"When max_trajectory_tokens is None, no overflow check occurs.\"\"\"\n        renderer = _make_renderer(gen_prompt_tokens=list(range(10000)), stop_sequences=[\"<s>\"])\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=1.0,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n            ),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        result = asyncio.run(env.step([1]))\n\n        assert result.episode_done is False\n        assert \"context_overflow\" not in result.metrics\n\n\nclass TestStepThreading:\n    def test_step_renders_in_thread(self):\n        \"\"\"On successful parse, the next observation rendering should use to_thread.\"\"\"\n        renderer = _make_renderer()\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.5,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n            ),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        with patch(\n            \"tinker_cookbook.rl.message_env.asyncio.to_thread\", wraps=asyncio.to_thread\n        ) as mock_to_thread:\n            asyncio.run(env.step([1, 2]))\n            mock_to_thread.assert_called_once()\n            assert mock_to_thread.call_args[0][0] is renderer.build_generation_prompt\n\n\nclass TestLogsPassthrough:\n    \"\"\"MessageStepResult.logs should be forwarded to StepResult.logs.\"\"\"\n\n    def test_logs_forwarded_on_success(self):\n        \"\"\"Logs from MessageEnv are passed through on normal step.\"\"\"\n        renderer = _make_renderer(gen_prompt_tokens=[1, 2, 3])\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.5,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n                logs={\"assistant\": \"hello world\", \"tool_call_0\": \"name=search\"},\n            ),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        result = asyncio.run(env.step([1]))\n\n        assert result.logs == {\"assistant\": \"hello world\", \"tool_call_0\": \"name=search\"}\n\n    def test_logs_forwarded_on_context_overflow(self):\n        \"\"\"Logs from MessageEnv are preserved even when context overflows.\"\"\"\n        renderer = _make_renderer(gen_prompt_tokens=list(range(100)))\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.5,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n                logs={\"assistant\": \"some response\", \"tool_result_0\": \"result data\"},\n            ),\n        )\n        env = EnvFromMessageEnv(\n            renderer=renderer,\n            message_env=msg_env,\n            max_trajectory_tokens=50,\n        )\n\n        result = 
asyncio.run(env.step([1]))\n\n        assert result.episode_done is True\n        assert result.metrics[\"context_overflow\"] == 1.0\n        assert result.logs == {\"assistant\": \"some response\", \"tool_result_0\": \"result data\"}\n\n    def test_no_logs_on_parse_error(self):\n        \"\"\"Parse errors bypass MessageEnv, so logs are empty.\"\"\"\n        renderer = _make_renderer(parse_success=False)\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=1.0,\n                episode_done=False,\n                next_messages=[],\n                logs={\"should_not\": \"appear\"},\n            ),\n        )\n        env = EnvFromMessageEnv(\n            renderer=renderer, message_env=msg_env, terminate_on_parse_error=True\n        )\n\n        result = asyncio.run(env.step([1]))\n\n        assert result.logs == {}\n\n    def test_empty_logs_by_default(self):\n        \"\"\"When MessageEnv doesn't set logs, StepResult.logs defaults to empty.\"\"\"\n        renderer = _make_renderer(gen_prompt_tokens=[1, 2])\n        msg_env = StubMessageEnv(\n            initial_messages=[],\n            step_result=MessageStepResult(\n                reward=0.5,\n                episode_done=False,\n                next_messages=[{\"role\": \"user\", \"content\": \"x\"}],\n            ),\n        )\n        env = EnvFromMessageEnv(renderer=renderer, message_env=msg_env)\n\n        result = asyncio.run(env.step([1]))\n\n        assert result.logs == {}\n"
  },
  {
    "path": "tinker_cookbook/rl/metric_util.py",
    "content": "import asyncio\nimport itertools\nimport logging\nfrom collections import defaultdict\n\nimport numpy as np\nimport tinker\n\nfrom tinker_cookbook.completers import TinkerTokenCompleter, TokenCompleter\nfrom tinker_cookbook.eval.evaluators import SamplingClientEvaluator\nfrom tinker_cookbook.exceptions import AllTrajectoriesFailedError\nfrom tinker_cookbook.rl.rollout_logging import (\n    RolloutSummaryExportConfig,\n    write_rollout_summaries_jsonl,\n)\nfrom tinker_cookbook.rl.rollout_strategy import RolloutStrategy\nfrom tinker_cookbook.rl.rollouts import (\n    RolloutErrorCounter,\n    do_group_rollout,\n    do_group_rollout_and_filter_constant_reward,\n    get_rollout_executor,\n)\nfrom tinker_cookbook.rl.types import EnvGroupBuilder, RLDataset, TrajectoryGroup\nfrom tinker_cookbook.utils import logtree\nfrom tinker_cookbook.utils.misc_utils import all_same, dict_mean\n\nlogger = logging.getLogger(__name__)\n\n\ndef _compute_by_group_metrics(trajectory_groups_P: list[TrajectoryGroup], good_thresh: float = 0.5):\n    n_groups = len(trajectory_groups_P)\n    n_mixed = n_good = n_bad = 0\n    for tg in trajectory_groups_P:\n        grp_rewards = tg.get_total_rewards()\n        if all_same(grp_rewards):\n            if grp_rewards[0] >= good_thresh:\n                n_good += 1\n            else:\n                n_bad += 1\n        else:\n            n_mixed += 1\n    return {\n        \"by_group/frac_mixed\": n_mixed / n_groups,\n        \"by_group/frac_all_good\": n_good / n_groups,\n        \"by_group/frac_all_bad\": n_bad / n_groups,\n    }\n\n\ndef compute_trajectory_metrics(\n    trajectory_groups_P: list[TrajectoryGroup], taglist_P: list[list[str]]\n) -> dict[str, float]:\n    tag2trajgroups = defaultdict(list)\n    for taglist, trajectory_group in zip(taglist_P, trajectory_groups_P):\n        for tag in taglist:\n            tag2trajgroups[tag].append(trajectory_group)\n    out = {}\n    have_nontrivial_tags = any(\n        len(trajgroups) < len(trajectory_groups_P) for trajgroups in tag2trajgroups.values()\n    )  # check if any tag gives us a strict subset of the full trajectory groups\n    if have_nontrivial_tags:\n        for tag, trajectory_groups in tag2trajgroups.items():\n            prefixed_metrics = {\n                f\"env/{tag}/{k}\": v\n                for k, v in _compute_trajectory_metrics(trajectory_groups).items()\n            }\n            out.update(prefixed_metrics)\n    out.update(\n        {f\"env/all/{k}\": v for k, v in _compute_trajectory_metrics(trajectory_groups_P).items()}\n    )\n    return out\n\n\ndef _compute_trajectory_metrics(trajectory_groups_P: list[TrajectoryGroup]) -> dict[str, float]:\n    \"\"\"Compute metrics for the trajectory groups.\"\"\"\n    flat_trajs_PG = [traj for tg in trajectory_groups_P for traj in tg.trajectories_G]\n    ac_tokens_by_turn = [\n        len(transition.ac.tokens) for traj in flat_trajs_PG for transition in traj.transitions\n    ]\n    ob_tokens_by_turn = [\n        transition.ob.length for traj in flat_trajs_PG for transition in traj.transitions\n    ]\n    turns_by_trajectory = [len(traj.transitions) for traj in flat_trajs_PG]\n    # Compute metrics\n    metrics = {\n        \"ac_tokens_per_turn\": sum(ac_tokens_by_turn) / sum(turns_by_trajectory),\n        \"ob_tokens_per_turn\": sum(ob_tokens_by_turn) / sum(turns_by_trajectory),\n        \"turns_per_episode\": sum(turns_by_trajectory) / len(flat_trajs_PG),\n        \"total_episodes\": len(flat_trajs_PG),\n        \"total_turns\": 
sum(turns_by_trajectory),\n        \"total_ac_tokens\": sum(ac_tokens_by_turn),\n        \"total_ob_tokens\": sum(ob_tokens_by_turn),\n    }\n    metrics[\"reward/total\"] = np.mean(\n        [reward for tg in trajectory_groups_P for reward in tg.get_total_rewards()]\n    ).item()\n    # Per-transition metrics\n    transition_metrics = [\n        transition.metrics\n        for tg in trajectory_groups_P\n        for traj in tg.trajectories_G\n        for transition in traj.transitions\n    ]\n    traj_metrics = [metrics for tg in trajectory_groups_P for metrics in tg.metrics_G]\n    metrics.update(dict_mean(transition_metrics + traj_metrics))\n    # combine traj_metrics and transition_metrics in case there's some key\n    # (like format error) that appears in the per-step metrics for some envs\n    # but the compute_group_rewards metric for other envs.\n    metrics.update(_compute_by_group_metrics(trajectory_groups_P))\n    return metrics\n\n\ndef dataset_to_env_group_builders(dataset: RLDataset) -> list[EnvGroupBuilder]:\n    \"\"\"\n    Get the whole dataset as a list of env group builders.\n    \"\"\"\n    return list(itertools.chain(*[dataset.get_batch(i) for i in range(len(dataset))]))\n\n\nclass RLTestSetEvaluator(SamplingClientEvaluator):\n    def __init__(\n        self,\n        dataset: RLDataset,\n        max_tokens: int,\n        name: str = \"test\",\n        num_groups_to_log: int = 4,\n        strategy: RolloutStrategy | None = None,\n    ):\n        self.env_group_builders_P = dataset_to_env_group_builders(dataset)\n        self.max_tokens = max_tokens\n        self.name = name\n        self.num_groups_to_log = num_groups_to_log\n        self.strategy = strategy\n\n    async def eval_token_completer(\n        self,\n        policy: TokenCompleter,\n        *,\n        rollout_summary_export: RolloutSummaryExportConfig | None = None,\n    ) -> dict[str, float]:\n        async def run_group_rollout(\n            builder: EnvGroupBuilder, group_idx: int\n        ) -> TrajectoryGroup | None:\n            enable_logging = group_idx < self.num_groups_to_log\n            try:\n                with logtree.optional_enable_logging(enable=enable_logging):\n                    result = await do_group_rollout(\n                        builder,\n                        policy,\n                        strategy=self.strategy,\n                    )\n            except AllTrajectoriesFailedError as e:\n                logger.warning(f\"Eval: {e}\")\n                result = None\n            except Exception as e:\n                if self.strategy is None or not self.strategy.catches_group_errors:\n                    raise\n                logger.warning(f\"Eval rollout error ({type(e).__name__}): {e}\")\n                result = None\n            return result\n\n        results = await asyncio.gather(\n            *[\n                run_group_rollout(builder, group_idx)\n                for group_idx, builder in enumerate(self.env_group_builders_P)\n            ]\n        )\n        return self._collect_eval_metrics(results, rollout_summary_export)\n\n    async def __call__(\n        self,\n        sampling_client: tinker.SamplingClient,\n        *,\n        rollout_summary_export: RolloutSummaryExportConfig | None = None,\n    ) -> dict[str, float]:\n        if get_rollout_executor() is not None:\n            # Use the executor-aware dispatch path so rollouts are offloaded\n            return await self._eval_with_executor(\n                sampling_client, 
rollout_summary_export=rollout_summary_export\n            )\n\n        policy = TinkerTokenCompleter(sampling_client, max_tokens=self.max_tokens)\n        return await self.eval_token_completer(\n            policy,\n            rollout_summary_export=rollout_summary_export,\n        )\n\n    async def _eval_with_executor(\n        self,\n        sampling_client: tinker.SamplingClient,\n        *,\n        rollout_summary_export: RolloutSummaryExportConfig | None = None,\n    ) -> dict[str, float]:\n        \"\"\"Run evaluation with rollouts dispatched via the rollout executor.\"\"\"\n        results = await asyncio.gather(\n            *[\n                do_group_rollout_and_filter_constant_reward(\n                    sampling_client,\n                    builder,\n                    max_tokens=self.max_tokens,\n                    temperature=1.0,\n                    do_remove_constant_reward_groups=False,\n                    enable_logging=i < self.num_groups_to_log,\n                    strategy=self.strategy,\n                )\n                for i, builder in enumerate(self.env_group_builders_P)\n            ]\n        )\n        return self._collect_eval_metrics(results, rollout_summary_export)\n\n    def _collect_eval_metrics(\n        self,\n        results: list[TrajectoryGroup | None],\n        rollout_summary_export: RolloutSummaryExportConfig | None,\n    ) -> dict[str, float]:\n        \"\"\"Shared logic for collecting metrics from eval rollout results.\"\"\"\n        error_counter = RolloutErrorCounter()\n        for result in results:\n            error_counter.ingest(result)\n\n        trajectory_groups_P = [r for r in results if r is not None]\n        taglist_P = [\n            builder.logging_tags()\n            for builder, r in zip(self.env_group_builders_P, results)\n            if r is not None\n        ]\n        if rollout_summary_export is not None:\n            sampling_client_steps_P = (\n                [rollout_summary_export.sampling_client_step] * len(trajectory_groups_P)\n                if rollout_summary_export.sampling_client_step is not None\n                else None\n            )\n            write_rollout_summaries_jsonl(\n                rollout_summary_export.path,\n                split=rollout_summary_export.split,\n                iteration=rollout_summary_export.iteration,\n                trajectory_groups_P=trajectory_groups_P,\n                taglist_P=taglist_P,\n                sampling_client_steps_P=sampling_client_steps_P,\n            )\n        metrics = compute_trajectory_metrics(trajectory_groups_P, taglist_P)\n        metrics.update(error_counter.get_metrics())\n        metrics = {f\"{self.name}/{k}\": v for k, v in metrics.items()}\n        return metrics\n"
  },
  {
    "path": "tinker_cookbook/rl/metrics.py",
    "content": "\"\"\"\nMetrics and KL computation functions for RL training.\n\nContains functions for computing KL divergences, incorporating KL penalties,\nand computing training metrics.\n\"\"\"\n\nimport asyncio\nfrom typing import Any, cast\n\nimport tinker\nimport torch\n\nfrom tinker_cookbook.utils import trace\nfrom tinker_cookbook.utils.misc_utils import safezip\n\n\ndef compute_kl_sample_train(\n    data_D: list[tinker.Datum], training_logprobs_D: list[torch.Tensor]\n) -> dict[str, float]:\n    \"\"\"Compute KL divergence metrics between sampling and training logprobs.\"\"\"\n    all_diffs: list[torch.Tensor] = []\n    all_sampling_logprobs: list[torch.Tensor] = []\n\n    for datum, training_logprobs in safezip(data_D, training_logprobs_D):\n        # Get logprobs from sampling\n        sampling_logprobs = datum.loss_fn_inputs[\"logprobs\"].to_torch()\n        action_mask = datum.loss_fn_inputs[\"mask\"].to_torch() > 0\n        # Extract only action token logprobs\n        sampling_logprobs_actions = sampling_logprobs[action_mask]\n        training_logprobs_actions = training_logprobs[action_mask]\n\n        if len(sampling_logprobs_actions) > 0:\n            logprob_diff = sampling_logprobs_actions - training_logprobs_actions\n            all_diffs.append(logprob_diff)\n            all_sampling_logprobs.append(sampling_logprobs_actions)\n\n    assert all_diffs\n    flat_diffs = torch.cat(all_diffs)\n    kl_sample_train_v1 = flat_diffs.mean().item()\n    kl_sample_train_v2 = 0.5 * (flat_diffs**2).mean().item()\n\n    flat_sampling_logprobs = torch.cat(all_sampling_logprobs)\n    entropy_sample = -flat_sampling_logprobs.mean().item()\n    return {\n        \"optim/kl_sample_train_v1\": kl_sample_train_v1,\n        \"optim/kl_sample_train_v2\": kl_sample_train_v2,\n        \"optim/entropy\": entropy_sample,\n    }\n\n\n@trace.scope\nasync def compute_post_kl(\n    data_D: list[tinker.Datum], post_sampling_client: tinker.SamplingClient\n) -> dict[str, float]:\n    \"\"\"Compute post-update KL divergence metrics.\"\"\"\n    # Compute logprobs at all data items\n    # This is a bit ugly, but we first reconstruct the original sequence from before we did the\n    # shifting to get the inputs and targets.\n    full_sequence_inputs_D = [\n        datum.model_input.append_int(cast(int, datum.loss_fn_inputs[\"target_tokens\"].data[-1]))\n        for datum in data_D\n    ]\n    new_logprobs_D = await asyncio.gather(\n        *[\n            post_sampling_client.compute_logprobs_async(sequence_input)\n            for sequence_input in full_sequence_inputs_D\n        ]\n    )\n\n    prev_logprobs_D = [datum.loss_fn_inputs[\"logprobs\"].to_torch() for datum in data_D]\n    action_masks = [datum.loss_fn_inputs[\"mask\"].to_torch() > 0 for datum in data_D]\n    flat_diffs = [\n        (prev_logprobs - torch.tensor(new_logprobs[1:]))[action_mask]\n        for new_logprobs, prev_logprobs, action_mask in safezip(\n            new_logprobs_D, prev_logprobs_D, action_masks\n        )\n    ]\n    flat_diffs = torch.cat(flat_diffs)\n    kl_post_v1 = flat_diffs.mean().item()\n    kl_post_v2 = 0.5 * (flat_diffs**2).mean().item()\n\n    return {\"kl_pre_post_v1\": kl_post_v1, \"kl_pre_post_v2\": kl_post_v2}\n\n\n@trace.scope\nasync def incorporate_kl_penalty(\n    data_D: list[tinker.Datum],\n    base_sampling_client: tinker.SamplingClient,\n    kl_penalty_coef: float,\n    kl_discount_factor: float,\n) -> dict[str, float]:\n    \"\"\"\n    Compute KL against base model. 
Adjust advantages in-place by logp_base - logp_current - avg_kl,\n    where avg_kl is the average of logp_base - logp_current (which is -KL[current, base])\n    \"\"\"\n    # Compute logprobs at all data items\n    full_sequence_inputs_D = [\n        datum.model_input.append_int(cast(int, datum.loss_fn_inputs[\"target_tokens\"].data[-1]))\n        for datum in data_D\n    ]\n    base_logprobs_D = await asyncio.gather(\n        *[\n            base_sampling_client.compute_logprobs_async(sequence_input)\n            for sequence_input in full_sequence_inputs_D\n        ]\n    )\n    # compute the logprob differences, zeroed out when the mask == 0\n    sampled_logprobs_D = [datum.loss_fn_inputs[\"logprobs\"].to_torch() for datum in data_D]\n    float_masks = [datum.loss_fn_inputs[\"mask\"].to_torch().float() for datum in data_D]\n    logprob_diffs = [\n        (sampled_logprobs - torch.tensor(base_logprobs[1:])) * mask\n        for base_logprobs, sampled_logprobs, mask in safezip(\n            base_logprobs_D, sampled_logprobs_D, float_masks\n        )\n    ]\n    avg_logp_diff = sum([diff.sum() for diff in logprob_diffs]) / sum(\n        [mask.sum() for mask in float_masks]\n    )\n    for i, datum in enumerate(data_D):\n        kl_advantages = kl_penalty_coef * float_masks[i] * (avg_logp_diff - logprob_diffs[i])\n        if kl_discount_factor > 0:\n            kl_advantages = discounted_future_sum_vectorized(kl_advantages, kl_discount_factor)\n        datum.loss_fn_inputs[\"advantages\"] = tinker.TensorData.from_torch(\n            datum.loss_fn_inputs[\"advantages\"].to_torch() + kl_advantages\n        )\n\n    return {\"kl_policy_base\": float(avg_logp_diff)}\n\n\ndef discounted_future_sum_vectorized(x: torch.Tensor, gamma: float) -> torch.Tensor:\n    \"\"\"\n    Compute discounted sum of future values for each position.\n\n    For position i, computes: sum_{k=0}^{T-1-i} gamma^k * x[i+k]\n\n    Args:\n        x: 1D tensor of values.\n        gamma: Discount factor.\n\n    Returns:\n        1D tensor of discounted future sums.\n    \"\"\"\n    result = torch.empty_like(x)\n    running = torch.zeros(1, dtype=x.dtype, device=x.device)\n    for t in range(len(x) - 1, -1, -1):\n        running = x[t] + gamma * running\n        result[t] = running\n    return result\n\n\ndef compute_sampling_client_metrics(\n    wrapped_trajectory_groups: list[Any],  # WrappedTrajectoryGroup\n) -> dict[str, Any]:\n    \"\"\"Compute metrics about sampling clients used to generate trajectory groups.\"\"\"\n    sampling_client_steps = [\n        wrapped_trajectory_group.sampling_client_step\n        for wrapped_trajectory_group in wrapped_trajectory_groups\n    ]\n    sample_times = [\n        wrapped_trajectory_group.metrics[\"time/trajectory_group_worker_loop/total\"]\n        for wrapped_trajectory_group in wrapped_trajectory_groups\n    ]\n    metrics = {}\n    metrics[\"sampling_client/step_max\"] = max(sampling_client_steps)\n    metrics[\"sampling_client/step_min\"] = min(sampling_client_steps)\n    metrics[\"sampling_client/step_mean\"] = sum(sampling_client_steps) / len(sampling_client_steps)\n    metrics[\"time/sampling_time_max\"] = max(sample_times)\n    metrics[\"time/sampling_time_min\"] = min(sample_times)\n    metrics[\"time/sampling_time_mean\"] = sum(sample_times) / len(sample_times)\n    return metrics\n"
  },
  {
    "path": "tinker_cookbook/rl/multiturn_weight_assignment_test.py",
    "content": "\"\"\"Tests for multi-turn weight assignment in trajectory_to_data.\n\nVerifies that agent-generated tokens get mask=1 (trained on) and\nenvironment-provided tokens get mask=0 (masked out) when trajectories\nare converted to training data.\n\"\"\"\n\nimport asyncio\nfrom unittest.mock import MagicMock\n\nimport tinker\n\nfrom tinker_cookbook.completers import TokenCompleter, TokensWithLogprobs\nfrom tinker_cookbook.renderers.base import Message, ToolCall\nfrom tinker_cookbook.rl.data_processing import trajectory_to_data\nfrom tinker_cookbook.rl.rollouts import do_single_rollout\nfrom tinker_cookbook.rl.types import Trajectory, Transition\nfrom tinker_cookbook.tool_use import build_agent_tool_env, simple_tool_result, tool\nfrom tinker_cookbook.tool_use.types import ToolResult\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _make_transition(\n    ob_tokens: list[int],\n    ac_tokens: list[int],\n    logprobs: list[float] | None = None,\n    reward: float = 0.0,\n    done: bool = False,\n) -> Transition:\n    if logprobs is None:\n        logprobs = [0.0] * len(ac_tokens)\n    return Transition(\n        ob=tinker.ModelInput.from_ints(ob_tokens),\n        ac=TokensWithLogprobs(tokens=ac_tokens, maybe_logprobs=logprobs),\n        reward=reward,\n        episode_done=done,\n    )\n\n\ndef _get_mask(datum: tinker.Datum) -> list[float]:\n    return datum.loss_fn_inputs[\"mask\"].to_torch().tolist()\n\n\n@tool\nasync def _stub_tool() -> ToolResult:\n    \"\"\"A test tool.\"\"\"\n    return simple_tool_result(\"ok\")\n\n\nasync def _zero_reward_fn(history) -> tuple[float, dict[str, float]]:\n    return 0.0, {}\n\n\n# ---------------------------------------------------------------------------\n# Test A: Multi-turn prefix trajectory\n# ---------------------------------------------------------------------------\n\n\nclass TestMultiTurnPrefixTrajectory:\n    \"\"\"3-turn trajectory where each observation is a prefix extension.\n\n    Turn 1: ob=[1,2,3,4,5]                              ac=[10,11,12]\n    Turn 2: ob=[1,2,3,4,5, 10,11,12, 20,21]             ac=[30,31]\n    Turn 3: ob=[1,2,3,4,5, 10,11,12, 20,21, 30,31, 40]  ac=[50,51,52]\n    \"\"\"\n\n    def _make_trajectory(self) -> Trajectory:\n        return Trajectory(\n            transitions=[\n                _make_transition([1, 2, 3, 4, 5], [10, 11, 12]),\n                _make_transition([1, 2, 3, 4, 5, 10, 11, 12, 20, 21], [30, 31]),\n                _make_transition(\n                    [1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 31, 40],\n                    [50, 51, 52],\n                    done=True,\n                ),\n            ],\n            final_ob=tinker.ModelInput.from_ints([]),\n        )\n\n    def test_returns_single_datum(self):\n        data = trajectory_to_data(self._make_trajectory(), traj_advantage=1.0)\n        assert len(data) == 1\n\n    def test_mask_matches_expected(self):\n        data = trajectory_to_data(self._make_trajectory(), traj_advantage=1.0)\n        mask = _get_mask(data[0])\n        # After [1:] shift:\n        # targets: [2,3,4,5, 10,11,12, 20,21, 30,31, 40, 50,51,52]\n        # mask:    [0,0,0,0,  1, 1, 1,  0, 0,  1, 1,  0,  1, 1, 1]\n        expected = [0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1]\n        assert mask == expected\n\n    def test_mask_sum_equals_action_token_count(self):\n        data = trajectory_to_data(self._make_trajectory(), 
traj_advantage=1.0)\n        mask = _get_mask(data[0])\n        # 3 (ac1) + 2 (ac2) + 3 (ac3) = 8 action tokens\n        assert sum(mask) == 8\n\n\n# ---------------------------------------------------------------------------\n# Test B: Single-turn trajectory\n# ---------------------------------------------------------------------------\n\n\nclass TestSingleTurnTrajectory:\n    def test_single_turn_mask(self):\n        traj = Trajectory(\n            transitions=[_make_transition([1, 2, 3], [10, 11, 12], done=True)],\n            final_ob=tinker.ModelInput.from_ints([]),\n        )\n        data = trajectory_to_data(traj, traj_advantage=1.0)\n        assert len(data) == 1\n        mask = _get_mask(data[0])\n        # After [1:]: targets=[2,3,10,11,12], mask=[0,0,1,1,1]\n        assert mask == [0, 0, 1, 1, 1]\n        assert sum(mask) == 3\n\n\n# ---------------------------------------------------------------------------\n# Test C: Prefix break splits into multiple Datums\n# ---------------------------------------------------------------------------\n\n\nclass TestPrefixBreak:\n    def test_non_prefix_observation_produces_two_datums(self):\n        traj = Trajectory(\n            transitions=[\n                _make_transition([1, 2, 3], [10, 11]),\n                _make_transition([50, 51, 52], [60, 61], done=True),\n            ],\n            final_ob=tinker.ModelInput.from_ints([]),\n        )\n        data = trajectory_to_data(traj, traj_advantage=1.0)\n        assert len(data) == 2\n\n    def test_each_datum_has_correct_mask(self):\n        traj = Trajectory(\n            transitions=[\n                _make_transition([1, 2, 3], [10, 11]),\n                _make_transition([50, 51, 52], [60, 61], done=True),\n            ],\n            final_ob=tinker.ModelInput.from_ints([]),\n        )\n        data = trajectory_to_data(traj, traj_advantage=1.0)\n        # Datum 1: [1,2,3,10,11] → after shift: mask=[0,0,1,1]\n        assert _get_mask(data[0]) == [0, 0, 1, 1]\n        # Datum 2: [50,51,52,60,61] → after shift: mask=[0,0,1,1]\n        assert _get_mask(data[1]) == [0, 0, 1, 1]\n\n\n# ---------------------------------------------------------------------------\n# Test D: End-to-end through build_agent_tool_env → rollout → trajectory_to_data\n# ---------------------------------------------------------------------------\n\n\ndef _make_stub_renderer():\n    \"\"\"Mock renderer with deterministic tokens and extension property.\n\n    Token mapping:\n        system  → [100, 101]\n        user    → [200, 201]\n        assistant → [500] (header) + action_tokens (output)\n        tool    → [400, 401]\n        suffix  → [500] (assistant generation header)\n\n    The suffix [500] matches the assistant header, so the extension\n    property holds: build_generation_prompt([msgs]) is always a prefix\n    of build_generation_prompt([msgs, asst_msg, tool_msg, ...]).\n    \"\"\"\n    renderer = MagicMock()\n    renderer.get_stop_sequences.return_value = [\"<stop>\"]\n\n    # Side-channel to pass action tokens from _parse to _build (can't stash on\n    # Message since it's a TypedDict).\n    action_tokens_by_id: dict[int, list[int]] = {}\n\n    parse_results = iter(\n        [\n            # Call 1: model makes a tool call\n            (\n                Message(\n                    role=\"assistant\",\n                    content=\"\",\n                    tool_calls=[\n                        ToolCall(\n                            function=ToolCall.FunctionBody(name=\"_stub_tool\", 
arguments=\"{}\"),\n                            id=\"call_1\",\n                        )\n                    ],\n                ),\n                True,\n            ),\n            # Call 2: model gives final answer (no tool calls)\n            (Message(role=\"assistant\", content=\"done\"), True),\n        ]\n    )\n\n    def _parse(tokens):\n        msg, success = next(parse_results)\n        action_tokens_by_id[id(msg)] = list(tokens)\n        return msg, success\n\n    renderer.parse_response.side_effect = _parse\n\n    def _build(messages, role=\"assistant\", prefill=None):\n        tokens = []\n        for msg in messages:\n            r = msg[\"role\"]\n            if r == \"system\":\n                tokens.extend([100, 101])\n            elif r == \"user\":\n                tokens.extend([200, 201])\n            elif r == \"assistant\":\n                tokens.extend([500])  # assistant header\n                tokens.extend(action_tokens_by_id.get(id(msg), []))\n            elif r == \"tool\":\n                tokens.extend([400, 401])\n        tokens.extend([500])  # suffix: assistant generation header\n        return tinker.ModelInput.from_ints(tokens)\n\n    renderer.build_generation_prompt.side_effect = _build\n    return renderer\n\n\ndef _make_stub_policy():\n    \"\"\"TokenCompleter returning predetermined responses.\"\"\"\n    responses = iter(\n        [\n            TokensWithLogprobs(tokens=[300, 301], maybe_logprobs=[-0.5, -0.3]),\n            TokensWithLogprobs(tokens=[310, 311], maybe_logprobs=[-0.4, -0.2]),\n        ]\n    )\n\n    class StubPolicy(TokenCompleter):\n        async def __call__(self, model_input, stop):\n            return next(responses)\n\n    return StubPolicy()\n\n\ndef _run_e2e_rollout():\n    \"\"\"Run end-to-end rollout and return (trajectory, data).\"\"\"\n    env = build_agent_tool_env(\n        renderer=_make_stub_renderer(),\n        tools=[_stub_tool],\n        initial_messages=[\n            Message(role=\"system\", content=\"You are helpful\"),\n            Message(role=\"user\", content=\"Do something\"),\n        ],\n        reward_fn=_zero_reward_fn,\n        max_turns=5,\n    )\n    traj = asyncio.run(do_single_rollout(_make_stub_policy(), env))\n    data = trajectory_to_data(traj, traj_advantage=1.0)\n    return traj, data\n\n\nclass TestEndToEndToolUseRollout:\n    def test_trajectory_has_two_transitions(self):\n        traj, _ = _run_e2e_rollout()\n        assert len(traj.transitions) == 2\n        assert traj.transitions[0].episode_done is False\n        assert traj.transitions[1].episode_done is True\n\n    def test_produces_single_datum(self):\n        _, data = _run_e2e_rollout()\n        assert len(data) == 1\n\n    def test_mask_only_on_agent_tokens(self):\n        _, data = _run_e2e_rollout()\n        mask = _get_mask(data[0])\n        # Full sequence: [100,101, 200,201, 500, 300,301, 400,401, 500, 310,311]\n        #                 sys      user     hdr  ac1      tool     hdr  ac2\n        # Full mask:      [0,  0,   0,  0,   0,   1,  1,   0,  0,   0,   1,  1]\n        # After [1:] shift:\n        # Mask:           [0,  0,   0,   0,   1,   1,   0,   0,   0,   1,   1]\n        expected = [0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1]\n        assert mask == expected\n        assert sum(mask) == 4  # 2 tokens per action × 2 actions\n"
  },
  {
    "path": "tinker_cookbook/rl/play_w_env.py",
    "content": "\"\"\"\nTo help you debug your environment, you can use the play_env function to play as the policy by typing in your responses in an environment interactively.\n\nOptions:\n- multiline=True: Enable multi-line input mode (terminate with two blank lines)\n\nWe include an example of playing the Twenty Questions environment in the main function.\nYou can run it with:\n\n```bash\npython -m tinker_cookbook.rl.play_w_env\n```\n\"\"\"\n\nimport asyncio\n\nimport tinker\nfrom termcolor import colored\n\nfrom tinker_cookbook.completers import (\n    StopCondition,\n    TokenCompleter,\n    TokensWithLogprobs,\n)\nfrom tinker_cookbook.rl.rollouts import do_single_rollout\nfrom tinker_cookbook.rl.types import Env, Trajectory\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n\nasync def get_async_input(prompt: str, multiline: bool = False) -> str:\n    loop = asyncio.get_event_loop()\n    if not multiline:\n        return await loop.run_in_executor(None, input, prompt)\n\n    # Multiline mode: collect lines until two consecutive blank lines\n    print(prompt + \" (enter two blank lines when done)\")\n    lines = []\n    prev_line = None\n    while True:\n        line = await loop.run_in_executor(None, input, \"\")\n        if line == \"\" and prev_line == \"\":\n            # Remove the first blank line from the list\n            if lines and lines[-1] == \"\":\n                lines.pop()\n            break\n        lines.append(line)\n        prev_line = line\n    return \"\\n\".join(lines)\n\n\nclass ManualPolicy(TokenCompleter):\n    def __init__(self, tokenizer: Tokenizer, multiline: bool = True, show_observation: bool = True):\n        self.tokenizer = tokenizer\n        self.step_count = 0\n        self.multiline = multiline\n        self.show_observation = show_observation\n\n    async def __call__(self, ob: tinker.ModelInput, stop: StopCondition) -> TokensWithLogprobs:\n        if self.show_observation:\n            observation_str = self.tokenizer.decode(ob.to_ints())\n            print(colored(f\"\\n--- Step {self.step_count} ---\", \"green\"))\n            print(colored(\"Observation:\", \"blue\"))\n            print(observation_str)\n            print(colored(\"-\" * 60, \"green\"))\n\n        prompt_text = \"Your action:\" if self.multiline else \"Your action: \"\n        action_str = await get_async_input(colored(prompt_text, \"yellow\"), multiline=self.multiline)\n        action_tokens = self.tokenizer.encode(action_str, add_special_tokens=False)\n        self.step_count += 1\n        return TokensWithLogprobs(tokens=action_tokens, maybe_logprobs=None)\n\n\ndef print_trajectory_summary(trajectory: Trajectory):\n    \"\"\"Print a summary of the completed trajectory.\"\"\"\n    print(colored(\"\\n=== Game Summary ===\", \"cyan\", attrs=[\"bold\"]))\n    total_reward = sum(t.reward for t in trajectory.transitions)\n    print(f\"Total steps: {len(trajectory.transitions)}\")\n    print(f\"Total reward: {total_reward}\")\n\n    if trajectory.transitions:\n        print(\"\\nReward per step:\")\n        for i, transition in enumerate(trajectory.transitions):\n            if transition.reward != 0:\n                print(f\"  Step {i}: reward = {transition.reward}\")\n\n    print(colored(\"===================\", \"cyan\", attrs=[\"bold\"]))\n\n\nasync def play_env(\n    env: Env, tokenizer: Tokenizer, multiline: bool = True, show_observation: bool = True\n):\n    \"\"\"Play a single-player environment interactively.\"\"\"\n    print(colored(\"Starting interactive 
environment session...\", \"cyan\", attrs=[\"bold\"]))\n    print(\"Type your actions when prompted. The session ends when the episode is done.\")\n\n    policy = ManualPolicy(tokenizer, multiline=multiline, show_observation=show_observation)\n    trajectory = await do_single_rollout(policy, env)\n\n    print_trajectory_summary(trajectory)\n    return trajectory\n\n\nasync def main():\n    from tinker_cookbook.recipes.multiplayer_rl.twenty_questions.env import (\n        construct_minimal_20q_env,\n    )\n\n    answer = \"apple\"\n    env = construct_minimal_20q_env(answer)\n    await play_env(env, env.renderer.tokenizer)\n\n\nif __name__ == \"__main__\":\n    asyncio.run(main())\n"
  },
  {
    "path": "tinker_cookbook/rl/preference_envs.py",
    "content": "import logging\nfrom collections.abc import Callable, Sequence\nfrom dataclasses import dataclass\nfrom enum import StrEnum\n\nimport chz\nimport tinker\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.completers import StopCondition\nfrom tinker_cookbook.exceptions import ConfigurationError\nfrom tinker_cookbook.preference.preference_datasets import (\n    ComparisonDatasetBuilder,\n)\nfrom tinker_cookbook.preference.types import (\n    Comparison,\n    LabeledComparison,\n    PreferenceModel,\n)\nfrom tinker_cookbook.rl.types import (\n    Action,\n    Env,\n    EnvGroupBuilder,\n    Metrics,\n    Observation,\n    RLDataset,\n    RLDatasetBuilder,\n    StepResult,\n    Trajectory,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.utils import logtree\nfrom tinker_cookbook.utils.logtree_formatters import ConversationFormatter\nfrom tinker_cookbook.utils.misc_utils import safezip\n\nlogger = logging.getLogger(__name__)\n\n\nclass PreferenceEnv(Env):\n    def __init__(\n        self,\n        convo_prefix: list[renderers.Message],\n        policy_renderer: renderers.Renderer,\n    ):\n        self.convo_prefix = convo_prefix\n        self.policy_renderer = policy_renderer\n\n    @property\n    def stop_condition(self) -> StopCondition:\n        return self.policy_renderer.get_stop_sequences()\n\n    async def initial_observation(self) -> tuple[Observation, StopCondition]:\n        return self.policy_renderer.build_generation_prompt(self.convo_prefix), self.stop_condition\n\n    async def step(self, action: Action) -> StepResult:\n        \"\"\"Compute the reward for a given action.\n\n        Args:\n            tokens: The tokens to compute the reward for.\n\n        Returns:\n            A tuple containing:\n                - reward (float): The reward for the given action.\n                - metrics (Dict[str, float]): Additional metrics to track.\n        \"\"\"\n        return StepResult(\n            reward=0,\n            episode_done=True,\n            next_observation=tinker.ModelInput.empty(),\n            next_stop_condition=self.stop_condition,\n            metrics={},\n        )\n\n\nclass TournamentPattern(StrEnum):\n    ALL_PAIRS_BOTH_WAYS = \"all_pairs_both_ways\"\n    ALL_PAIRS_ONE_WAY = \"all_pairs_one_way\"\n\n\ndef get_pairs_chunked(n: int, pattern: TournamentPattern, chunk_size: int) -> list[tuple[int, int]]:\n    \"\"\"\n    Get pairs of indices of matchups of n players. 
If chunk_size < n, then we divide the players\n    into groups of at most chunk_size and get the matchup indices within each group.\n    \"\"\"\n    out = []\n    for chunk_start in range(0, n, chunk_size):\n        chunk_end = min(chunk_start + chunk_size, n)\n        pairs = get_pairs(chunk_end - chunk_start, pattern)\n        for i, j in pairs:\n            out.append((chunk_start + i, chunk_start + j))\n    return out\n\n\ndef get_pairs(n: int, pattern: TournamentPattern) -> list[tuple[int, int]]:\n    if pattern == TournamentPattern.ALL_PAIRS_BOTH_WAYS:\n        return [(i, j) for i in range(n) for j in range(n) if i != j]\n    elif pattern == TournamentPattern.ALL_PAIRS_ONE_WAY:\n        return [(i, j) for i in range(n) for j in range(i + 1, n)]\n    else:\n        raise ConfigurationError(f\"Invalid tournament pattern: {pattern}\")\n\n\n@dataclass(frozen=True)\nclass PairwisePreferenceGroupBuilder(EnvGroupBuilder):\n    convo_prefix: list[renderers.Message]\n    policy_renderer: renderers.Renderer\n    tournament_pattern: TournamentPattern\n    preference_model: PreferenceModel\n    num_envs: int\n    content_preprocessor: Callable[[str], str] | None = None  # e.g. strip out <thinking> tags\n    matchup_group_size: int = 4  # divide group into smaller groups of this size for matchups\n    eval_target_completion_A: list[renderers.Message] | None = None\n\n    async def make_envs(self) -> Sequence[Env]:\n        return [\n            PreferenceEnv(self.convo_prefix, self.policy_renderer) for _ in range(self.num_envs)\n        ]\n\n    def _preprocess_message(self, message: renderers.Message) -> renderers.Message:\n        if self.content_preprocessor is not None:\n            content = renderers.get_text_content(message)\n            message = {**message, \"content\": self.content_preprocessor(content)}\n        return message\n\n    def get_response_message(self, trajectory: Trajectory) -> tuple[list[renderers.Message], bool]:\n        response, is_valid = self.policy_renderer.parse_response(\n            trajectory.transitions[0].ac.tokens\n        )\n        return [response], is_valid\n\n    def comparison_reward_for_second_messages(\n        self, message_i: list[renderers.Message], message_j: list[renderers.Message]\n    ) -> Comparison:\n        comparison = Comparison(\n            prompt_conversation=self.convo_prefix,\n            completion_A=[self._preprocess_message(m) for m in message_i],\n            completion_B=[self._preprocess_message(m) for m in message_j],\n        )\n        return comparison\n\n    @logtree.scope_header_decorator\n    async def compute_group_rewards(\n        self,\n        trajectory_group: list[Trajectory],\n        env_group: Sequence[Env],\n    ) -> list[tuple[float, Metrics]]:\n        assert all(len(trajectory.transitions) == 1 for trajectory in trajectory_group)\n        # Get response from each trajectory\n        response_tuples = [self.get_response_message(trajectory) for trajectory in trajectory_group]\n        response_messages, is_valid_list = safezip(*response_tuples)\n\n        # Log prompt\n        with logtree.scope_header(\"Prompt\"):\n            logtree.log_formatter(ConversationFormatter(messages=self.convo_prefix))\n\n        # Log trajectories\n        for idx, (messages, is_valid) in enumerate(\n            zip(response_messages, is_valid_list, strict=True)\n        ):\n            with logtree.scope_header(f\"Completion {idx}\"):\n                logtree.log_formatter(ConversationFormatter(messages=messages))\n           
     logtree.log_text(f\"Valid format: {is_valid}\")\n\n        # if the matching group size is 3 and len(response_messages) is 6\n        # then it will return something like\n        # [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5)]\n        # so we don't end up with O(n^2) comparisons\n        comparison_indices_pairs = get_pairs_chunked(\n            len(response_messages), self.tournament_pattern, self.matchup_group_size\n        )\n\n        logtree.log_text(\n            f\"Got {len(trajectory_group)} trajectories, doing {len(comparison_indices_pairs)} pairwise matchups.\"\n        )\n\n        j_comparisons = [\n            self.comparison_reward_for_second_messages(\n                message_i=response_messages[i], message_j=response_messages[j]\n            )\n            for i, j in comparison_indices_pairs\n        ]\n\n        # Log each pairwise comparison with its reward\n        with logtree.scope_header(\"Pairwise Comparisons\"):\n            j_rewards = []\n\n            # Compute all rewards first\n            for comparison in j_comparisons:\n                reward = await self.preference_model(comparison)\n                j_rewards.append(reward)\n\n            # Log summary of all matchups\n            for idx, ((i, j), reward) in enumerate(\n                zip(comparison_indices_pairs, j_rewards, strict=True)\n            ):\n                logtree.log_text(f\"Matchup {idx}: ({i} vs {j}) — Reward: {reward:.2f}\")\n\n        win_minus_loss_list = [0.0 for _ in range(len(response_messages))]\n        matchup_count = [0 for _ in range(len(response_messages))]\n        for (i, j), j_reward in safezip(comparison_indices_pairs, j_rewards):\n            win_minus_loss_list[j] += j_reward\n            win_minus_loss_list[i] -= j_reward\n            matchup_count[j] += 1\n            matchup_count[i] += 1\n        format_coef = 1.0\n\n        return [\n            (\n                win_minus_loss / matchup_count + format_coef * (float(is_valid) - 1.0),\n                {\"win_minus_loss\": win_minus_loss / matchup_count, \"format\": is_valid},\n            )\n            for win_minus_loss, is_valid, matchup_count in safezip(\n                win_minus_loss_list, is_valid_list, matchup_count\n            )\n        ]\n\n    def logging_tags(self) -> list[str]:\n        return [\"pair_pref\"]\n\n\nclass PairwisePreferenceDataset(RLDataset):\n    def __init__(\n        self,\n        comparison_builder: ComparisonDatasetBuilder,\n        renderer: renderers.Renderer,\n        batch_size: int,\n        preference_model: PreferenceModel,\n        tournament_pattern: TournamentPattern = TournamentPattern.ALL_PAIRS_BOTH_WAYS,\n        group_size: int = 4,\n        content_preprocessor: Callable[[str], str] | None = None,\n    ):\n        self.comparison_builder = comparison_builder\n        self.renderer = renderer\n        self.batch_size = batch_size\n        self.preference_model = preference_model\n        self.train_dataset, _ = self.comparison_builder.get_train_and_test_datasets()\n        self.tournament_pattern = tournament_pattern\n        self.group_size = group_size\n        self.content_preprocessor = content_preprocessor\n\n    def get_batch(self, index: int) -> list[EnvGroupBuilder]:\n        rows = self.train_dataset.select(\n            range(index * self.batch_size, (index + 1) * self.batch_size)\n        )\n        lcs = [self.comparison_builder.example_to_labeled_comparison(row) for row in rows]  # type: ignore\n        return 
[self._labeled_comparison_to_env_group(lc) for lc in lcs if lc is not None]\n\n    def _labeled_comparison_to_env_group(self, lc: LabeledComparison) -> EnvGroupBuilder:\n        return PairwisePreferenceGroupBuilder(\n            convo_prefix=lc.comparison.prompt_conversation,\n            policy_renderer=self.renderer,\n            preference_model=self.preference_model,\n            tournament_pattern=self.tournament_pattern,\n            num_envs=self.group_size,\n            content_preprocessor=self.content_preprocessor,\n        )\n\n    def __len__(self) -> int:\n        return len(self.train_dataset) // self.batch_size\n\n\n@chz.chz\nclass PairwisePreferenceRLDatasetBuilder(RLDatasetBuilder):\n    comparison_builder: ComparisonDatasetBuilder\n    batch_size: int\n    policy_renderer_name: str\n    policy_model_name: str\n    tournament_pattern: TournamentPattern = TournamentPattern.ALL_PAIRS_BOTH_WAYS\n    group_size: int\n    content_preprocessor: Callable[[str], str] | None = None\n    preference_model_builder: Callable[[], PreferenceModel]\n\n    async def __call__(self) -> tuple[PairwisePreferenceDataset, None]:\n        policy_renderer = renderers.get_renderer(\n            self.policy_renderer_name, get_tokenizer(self.policy_model_name)\n        )\n        return PairwisePreferenceDataset(\n            comparison_builder=self.comparison_builder,\n            renderer=policy_renderer,\n            batch_size=self.batch_size,\n            preference_model=self.preference_model_builder(),\n            tournament_pattern=self.tournament_pattern,\n            group_size=self.group_size,\n            content_preprocessor=self.content_preprocessor,\n        ), None\n"
  },
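The chunked tournament logic in `get_pairs_chunked` is easiest to see with concrete numbers. The values below reproduce what the functions above compute, matching the inline comment about a group of 6 completions with `matchup_group_size=3`:

```python
from tinker_cookbook.rl.preference_envs import TournamentPattern, get_pairs, get_pairs_chunked

# Within one chunk, ONE_WAY plays each unordered pair once; BOTH_WAYS plays both orderings.
assert get_pairs(3, TournamentPattern.ALL_PAIRS_ONE_WAY) == [(0, 1), (0, 2), (1, 2)]
assert get_pairs(2, TournamentPattern.ALL_PAIRS_BOTH_WAYS) == [(0, 1), (1, 0)]

# With 6 completions and chunk_size=3, matchups stay inside chunks {0,1,2} and {3,4,5},
# so the number of comparisons grows linearly with group size rather than O(n^2).
assert get_pairs_chunked(6, TournamentPattern.ALL_PAIRS_ONE_WAY, chunk_size=3) == [
    (0, 1), (0, 2), (1, 2),
    (3, 4), (3, 5), (4, 5),
]
```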
  {
    "path": "tinker_cookbook/rl/problem_env.py",
    "content": "import logging\nfrom abc import abstractmethod\nfrom collections.abc import Callable, Sequence\nfrom dataclasses import dataclass\n\nimport tinker\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.completers import StopCondition\nfrom tinker_cookbook.rl.types import (\n    Action,\n    Env,\n    EnvGroupBuilder,\n    Metrics,\n    Observation,\n    StepResult,\n    Trajectory,\n)\nfrom tinker_cookbook.utils import logtree\nfrom tinker_cookbook.utils.logtree_formatters import ConversationFormatter\n\nlogger = logging.getLogger(__name__)\n\n\nclass ProblemEnv(Env):\n    def __init__(\n        self,\n        renderer: renderers.Renderer,\n        convo_prefix: list[renderers.Message] | None = None,\n        format_coef: float = 0.1,\n    ):\n        self.renderer = renderer\n        self.convo_prefix = convo_prefix or []\n        self.format_coef = format_coef\n\n    @property\n    def stop_condition(self) -> StopCondition:\n        return self.renderer.get_stop_sequences()\n\n    @abstractmethod\n    def get_question(self) -> str:\n        pass\n\n    @abstractmethod\n    def check_answer(self, sample_str: str) -> bool:\n        pass\n\n    @abstractmethod\n    def check_format(self, sample_str: str) -> bool:\n        pass\n\n    @abstractmethod\n    def get_reference_answer(self) -> str:\n        \"\"\"Return the reference answer for logging purposes.\"\"\"\n        pass\n\n    async def initial_observation(self) -> tuple[Observation, StopCondition]:\n        convo = self.convo_prefix + [\n            {\"role\": \"user\", \"content\": self.get_question()},\n        ]\n        return self.renderer.build_generation_prompt(convo), self.stop_condition\n\n    async def step(self, action: Action) -> StepResult:\n        convo = self.convo_prefix + [{\"role\": \"user\", \"content\": self.get_question()}]\n        message, parse_success = self.renderer.parse_response(action)\n        content = renderers.get_text_content(message)\n        correct_format = float(parse_success) and float(self.check_format(content))\n        correct_answer = float(self.check_answer(content))\n        total_reward = self.format_coef * (correct_format - 1) + correct_answer\n\n        # Log the attempt in a fixed structure that scales to longer content.\n        with logtree.scope_header(\"Prompt\"):\n            logtree.log_formatter(ConversationFormatter(messages=convo))\n        with logtree.scope_header(\"Policy Response\"):\n            logtree.log_formatter(ConversationFormatter(messages=[message]))\n        with logtree.scope_header(\"Reward\"):\n            logtree.table_from_dict(\n                {\n                    \"reference_answer\": self.get_reference_answer(),\n                    \"format_valid\": bool(correct_format),\n                    \"correct\": bool(correct_answer),\n                    \"format_coef\": self.format_coef,\n                    \"reward\": f\"{total_reward:.3f}\",\n                },\n                caption=\"Reward components\",\n            )\n\n        return StepResult(\n            reward=total_reward,\n            episode_done=True,\n            next_observation=tinker.ModelInput.empty(),\n            next_stop_condition=self.stop_condition,\n            metrics={\n                \"format\": correct_format,\n                \"correct\": correct_answer,\n            },\n        )\n\n\n@dataclass(frozen=True)\nclass ProblemGroupBuilder(EnvGroupBuilder):\n    env_thunk: Callable[[], ProblemEnv]\n    num_envs: int\n    dataset_name: str = 
\"problems\"\n\n    async def make_envs(self) -> Sequence[Env]:\n        return [self.env_thunk() for _ in range(self.num_envs)]\n\n    async def compute_group_rewards(\n        self, trajectory_group: list[Trajectory], env_group: Sequence[Env]\n    ) -> list[tuple[float, Metrics]]:\n        return [(0.0, {}) for _ in range(len(trajectory_group))]\n\n    def logging_tags(self) -> list[str]:\n        return [self.dataset_name]\n"
  },
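`ProblemEnv` is abstract; a concrete environment only needs the four methods above, and `ProblemGroupBuilder` wraps a zero-argument constructor for it. A small sketch under stated assumptions: `AdditionEnv`, the renderer name, and the model name are illustrative, not cookbook definitions.

```python
# Hypothetical sketch of a concrete ProblemEnv plus its group builder.
from functools import partial

from tinker_cookbook import renderers
from tinker_cookbook.rl.problem_env import ProblemEnv, ProblemGroupBuilder
from tinker_cookbook.tokenizer_utils import get_tokenizer


class AdditionEnv(ProblemEnv):
    def __init__(self, renderer: renderers.Renderer, a: int, b: int):
        super().__init__(renderer)
        self.a, self.b = a, b

    def get_question(self) -> str:
        return f"What is {self.a} + {self.b}? Answer with just the number."

    def check_answer(self, sample_str: str) -> bool:
        return sample_str.strip() == str(self.a + self.b)

    def check_format(self, sample_str: str) -> bool:
        return sample_str.strip().isdigit()

    def get_reference_answer(self) -> str:
        return str(self.a + self.b)


# Renderer and model names below are assumptions for illustration.
renderer = renderers.get_renderer("role_colon", get_tokenizer("meta-llama/Llama-3.2-1B"))
group_builder = ProblemGroupBuilder(
    env_thunk=partial(AdditionEnv, renderer, 17, 25),  # fresh env per rollout
    num_envs=4,
    dataset_name="addition",
)
```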
  {
    "path": "tinker_cookbook/rl/rollout_error_resilience_test.py",
    "content": "\"\"\"Tests for rollout error resilience: strategy abstraction, retry, error tracking, and pickling.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport pickle\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\nimport tinker\n\nfrom tinker_cookbook.completers import TokenCompleter, TokensWithLogprobs\nfrom tinker_cookbook.exceptions import AllTrajectoriesFailedError, ConfigurationError\nfrom tinker_cookbook.rl.rollout_strategy import (\n    FailFast,\n    RetryOnFailure,\n    rollout_strategy_from_config,\n)\nfrom tinker_cookbook.rl.rollouts import (\n    RolloutErrorCounter,\n    _do_group_rollout_and_filter_constant_reward_impl,\n    do_group_rollout,\n)\nfrom tinker_cookbook.rl.types import (\n    Env,\n    EnvGroupBuilder,\n    RolloutError,\n    StepResult,\n    Trajectory,\n    TrajectoryGroup,\n    Transition,\n)\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _make_trajectory() -> Trajectory:\n    \"\"\"Create a minimal valid Trajectory.\"\"\"\n    return Trajectory(\n        transitions=[\n            Transition(\n                ob=tinker.ModelInput.from_ints([1, 2, 3]),\n                ac=TokensWithLogprobs(tokens=[4, 5], maybe_logprobs=[0.1, 0.2]),\n                reward=1.0,\n                episode_done=True,\n            )\n        ],\n        final_ob=tinker.ModelInput.from_ints([]),\n    )\n\n\nclass _FakePolicy(TokenCompleter):\n    \"\"\"Policy that returns a fixed result, optionally failing on specific call indices.\"\"\"\n\n    def __init__(self, fail_indices: set[int] | None = None, error: BaseException | None = None):\n        self._call_count = 0\n        self.fail_indices = fail_indices or set()\n        self.error = error or RuntimeError(\"fake error\")\n\n    async def __call__(self, model_input, stop):\n        idx = self._call_count\n        self._call_count += 1\n        if idx in self.fail_indices:\n            raise self.error\n        return TokensWithLogprobs(tokens=[4, 5], maybe_logprobs=[0.1, 0.2])\n\n\nclass _FakeEnv(Env):\n    async def initial_observation(self):\n        return tinker.ModelInput.from_ints([1, 2, 3]), [0]\n\n    async def step(self, action):\n        return StepResult(\n            reward=1.0,\n            episode_done=True,\n            next_observation=tinker.ModelInput.from_ints([]),\n            next_stop_condition=[0],\n        )\n\n\nclass _FakeEnvGroupBuilder(EnvGroupBuilder):\n    def __init__(self, n_envs: int = 4):\n        self.n_envs = n_envs\n        self.make_envs_call_count = 0\n\n    async def make_envs(self):\n        self.make_envs_call_count += 1\n        return [_FakeEnv() for _ in range(self.n_envs)]\n\n\n# ---------------------------------------------------------------------------\n# rollout_strategy_from_config tests\n# ---------------------------------------------------------------------------\n\n\nclass TestRolloutStrategyFromConfig:\n    def test_false_returns_fail_fast(self):\n        strategy = rollout_strategy_from_config(False)\n        assert isinstance(strategy, FailFast)\n        assert not strategy.catches_group_errors\n\n    def test_true_returns_retry_on_failure(self):\n        strategy = rollout_strategy_from_config(True)\n        assert isinstance(strategy, RetryOnFailure)\n        assert strategy.max_retries == 3\n        assert strategy.catches_group_errors\n\n    def test_strategy_instance_passed_through(self):\n        
strategy = RetryOnFailure(max_retries=5)\n        assert rollout_strategy_from_config(strategy) is strategy\n\n    def test_fail_fast_instance_passed_through(self):\n        strategy = FailFast()\n        assert rollout_strategy_from_config(strategy) is strategy\n\n    def test_invalid_value_raises(self):\n        with pytest.raises(ConfigurationError):\n            rollout_strategy_from_config(0.5)  # type: ignore[arg-type]\n\n\n# ---------------------------------------------------------------------------\n# Strategy pickling tests\n# ---------------------------------------------------------------------------\n\n\nclass TestStrategyPickle:\n    def test_fail_fast_pickleable(self):\n        strategy = FailFast()\n        restored = pickle.loads(pickle.dumps(strategy))\n        assert isinstance(restored, FailFast)\n\n    def test_retry_on_failure_pickleable(self):\n        strategy = RetryOnFailure(max_retries=5)\n        restored = pickle.loads(pickle.dumps(strategy))\n        assert isinstance(restored, RetryOnFailure)\n        assert restored.max_retries == 5\n\n\n# ---------------------------------------------------------------------------\n# RolloutError and TrajectoryGroup tests\n# ---------------------------------------------------------------------------\n\n\nclass TestRolloutError:\n    def test_pickleable(self):\n        err = RolloutError(error_type=\"BadRequestError\", error_message=\"context overflow\")\n        restored = pickle.loads(pickle.dumps(err))\n        assert restored.error_type == \"BadRequestError\"\n        assert restored.error_message == \"context overflow\"\n\n\nclass TestTrajectoryGroupErrors:\n    def test_default_no_errors(self):\n        tg = TrajectoryGroup([_make_trajectory()], [1.0], [{}])\n        assert tg.rollout_errors == []\n\n    def test_with_errors(self):\n        errors = [RolloutError(\"BadRequestError\", \"too long\")]\n        tg = TrajectoryGroup([_make_trajectory()], [1.0], [{}], rollout_errors=errors)\n        assert len(tg.rollout_errors) == 1\n\n    def test_pickleable_with_errors(self):\n        errors = [RolloutError(\"BadRequestError\", \"too long\")]\n        tg = TrajectoryGroup([_make_trajectory()], [1.0], [{}], rollout_errors=errors)\n        restored = pickle.loads(pickle.dumps(tg))\n        assert len(restored.rollout_errors) == 1\n\n    def test_get_total_rewards_unaffected(self):\n        errors = [RolloutError(\"Err\", \"msg\")]\n        tg = TrajectoryGroup(\n            [_make_trajectory(), _make_trajectory()],\n            [0.5, 0.3],\n            [{}, {}],\n            rollout_errors=errors,\n        )\n        rewards = tg.get_total_rewards()\n        assert rewards[0] == pytest.approx(1.5)\n        assert rewards[1] == pytest.approx(1.3)\n\n\n# ---------------------------------------------------------------------------\n# RolloutErrorCounter tests\n# ---------------------------------------------------------------------------\n\n\nclass TestRolloutErrorCounter:\n    def test_ingest_successful_group(self):\n        counter = RolloutErrorCounter()\n        tg = TrajectoryGroup([_make_trajectory()], [1.0], [{}])\n        counter.ingest(tg)\n        assert counter.get_metrics() == {}\n\n    def test_ingest_none_increments_groups_skipped(self):\n        counter = RolloutErrorCounter()\n        counter.ingest(None)\n        counter.ingest(None)\n        metrics = counter.get_metrics()\n        assert metrics[\"rollout_errors/groups_skipped\"] == 2.0\n\n    def test_ingest_group_with_errors(self):\n        counter = 
RolloutErrorCounter()\n        errors = [\n            RolloutError(\"BadRequestError\", \"msg1\"),\n            RolloutError(\"BadRequestError\", \"msg2\"),\n            RolloutError(\"TimeoutError\", \"msg3\"),\n        ]\n        tg = TrajectoryGroup([_make_trajectory()], [1.0], [{}], rollout_errors=errors)\n        counter.ingest(tg)\n        metrics = counter.get_metrics()\n        assert metrics[\"rollout_errors/BadRequestError\"] == 2.0\n        assert metrics[\"rollout_errors/TimeoutError\"] == 1.0\n        assert metrics[\"rollout_errors/total\"] == 3.0\n\n    def test_cumulative_across_ingests(self):\n        counter = RolloutErrorCounter()\n        tg1 = TrajectoryGroup(\n            [_make_trajectory()],\n            [1.0],\n            [{}],\n            rollout_errors=[RolloutError(\"BadRequestError\", \"a\")],\n        )\n        counter.ingest(tg1)\n        counter.ingest(None)\n        metrics = counter.get_metrics()\n        assert metrics[\"rollout_errors/BadRequestError\"] == 1.0\n        assert metrics[\"rollout_errors/groups_skipped\"] == 1.0\n\n\n# ---------------------------------------------------------------------------\n# FailFast strategy tests (via do_group_rollout)\n# ---------------------------------------------------------------------------\n\n\nclass TestFailFastStrategy:\n    def test_default_strategy_raises_on_error(self):\n        \"\"\"Without strategy (FailFast default), errors propagate.\"\"\"\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        policy = _FakePolicy(fail_indices={1})\n        with pytest.raises(RuntimeError, match=\"fake error\"):\n            asyncio.run(do_group_rollout(builder, policy))\n\n    def test_success_returns_all_trajectories(self):\n        builder = _FakeEnvGroupBuilder(n_envs=3)\n        policy = _FakePolicy()\n        tg = asyncio.run(do_group_rollout(builder, policy))\n        assert len(tg.trajectories_G) == 3\n        assert tg.rollout_errors == []\n\n    def test_cancelled_error_propagates(self):\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        policy = _FakePolicy(fail_indices={0}, error=asyncio.CancelledError())\n        with pytest.raises(asyncio.CancelledError):\n            asyncio.run(do_group_rollout(builder, policy))\n\n\n# ---------------------------------------------------------------------------\n# RetryOnFailure strategy tests (via do_group_rollout)\n# ---------------------------------------------------------------------------\n\n\nclass TestRetryOnFailureStrategy:\n    def test_no_errors_returns_all_trajectories(self):\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        policy = _FakePolicy()\n        tg = asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=3)))\n        assert len(tg.trajectories_G) == 2\n        assert tg.rollout_errors == []\n\n    def test_retry_recovers_from_transient_failure(self):\n        \"\"\"One trajectory fails initially, retry succeeds.\"\"\"\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        # Call index 1 fails, but retry (index 2) succeeds\n        policy = _FakePolicy(fail_indices={1})\n        tg = asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=3)))\n        # Original success + retry success = 2 trajectories\n        assert len(tg.trajectories_G) == 2\n        assert len(tg.rollout_errors) == 1\n        assert tg.rollout_errors[0].error_type == \"RuntimeError\"\n\n    def test_retry_creates_fresh_envs(self):\n        \"\"\"Retry calls make_envs again to get a fresh 
environment.\"\"\"\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        policy = _FakePolicy(fail_indices={1})  # one failure triggers one retry\n        asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=3)))\n        # Initial make_envs + 1 retry make_envs\n        assert builder.make_envs_call_count == 2\n\n    def test_all_fail_raises_after_retries(self):\n        \"\"\"All trajectories fail and retries exhausted -> re-raises last error.\"\"\"\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        # All calls fail (indices 0,1 initial + 2,3,4 retries = 5 total calls)\n        policy = _FakePolicy(fail_indices={0, 1, 2, 3, 4})\n        with pytest.raises(RuntimeError, match=\"fake error\"):\n            asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=3)))\n\n    def test_budget_exhausted_cancels_and_raises(self):\n        \"\"\"When retry budget runs out, cancel remaining tasks and re-raise.\"\"\"\n        builder = _FakeEnvGroupBuilder(n_envs=4)\n        # Indices 0,1 succeed; 2 fails, retry at 4 fails; 3 fails, retry at 5 fails\n        # Budget is 2 — first retry succeeds at index 4? No: indices 2,3 fail, retries at 4,5 also fail\n        # After 2 retries exhausted, next failure re-raises\n        policy = _FakePolicy(fail_indices={2, 3, 4, 5})\n        with pytest.raises(RuntimeError, match=\"fake error\"):\n            asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=2)))\n\n    def test_zero_retries_raises_on_any_failure(self):\n        \"\"\"max_retries=0 means no retries — any failure crashes the group.\"\"\"\n        builder = _FakeEnvGroupBuilder(n_envs=4)\n        policy = _FakePolicy(fail_indices={2})\n        with pytest.raises(RuntimeError, match=\"fake error\"):\n            asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=0)))\n\n    def test_cancelled_error_not_swallowed(self):\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        policy = _FakePolicy(fail_indices={0}, error=asyncio.CancelledError())\n        with pytest.raises(asyncio.CancelledError):\n            asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=3)))\n\n    def test_keyboard_interrupt_not_swallowed(self):\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        policy = _FakePolicy(fail_indices={0}, error=KeyboardInterrupt())\n        with pytest.raises(KeyboardInterrupt):\n            asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=3)))\n\n    def test_make_envs_failure_during_retry_propagates(self):\n        \"\"\"If make_envs() fails during retry, the error propagates.\"\"\"\n\n        call_count = 0\n\n        class _FailOnSecondMakeEnvs(EnvGroupBuilder):\n            async def make_envs(self):\n                nonlocal call_count\n                call_count += 1\n                if call_count > 1:\n                    raise RuntimeError(\"container pool exhausted\")\n                return [_FakeEnv() for _ in range(2)]\n\n        builder = _FailOnSecondMakeEnvs()\n        policy = _FakePolicy(fail_indices={1})  # triggers a retry\n        with pytest.raises(RuntimeError, match=\"container pool exhausted\"):\n            asyncio.run(do_group_rollout(builder, policy, strategy=RetryOnFailure(max_retries=3)))\n\n\n# ---------------------------------------------------------------------------\n# _do_group_rollout_and_filter_constant_reward_impl tests\n# 
---------------------------------------------------------------------------\n\n\nclass TestImplErrorHandling:\n    def test_fail_fast_propagates_error(self):\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        sampling_client = MagicMock(spec=tinker.SamplingClient)\n        with (\n            patch(\n                \"tinker_cookbook.rl.rollouts.do_group_rollout\",\n                side_effect=RuntimeError(\"boom\"),\n            ),\n            pytest.raises(RuntimeError, match=\"boom\"),\n        ):\n            asyncio.run(\n                _do_group_rollout_and_filter_constant_reward_impl(\n                    sampling_client,\n                    builder,\n                    max_tokens=100,\n                    temperature=1.0,\n                    do_remove_constant_reward_groups=False,\n                    strategy=FailFast(),\n                )\n            )\n\n    def test_retry_strategy_returns_none_on_group_error(self):\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        sampling_client = MagicMock(spec=tinker.SamplingClient)\n        with patch(\n            \"tinker_cookbook.rl.rollouts.do_group_rollout\",\n            side_effect=RuntimeError(\"boom\"),\n        ):\n            result = asyncio.run(\n                _do_group_rollout_and_filter_constant_reward_impl(\n                    sampling_client,\n                    builder,\n                    max_tokens=100,\n                    temperature=1.0,\n                    do_remove_constant_reward_groups=False,\n                    strategy=RetryOnFailure(max_retries=3),\n                )\n            )\n        assert result is None\n\n    def test_all_trajectories_failed_returns_none(self):\n        builder = _FakeEnvGroupBuilder(n_envs=2)\n        sampling_client = MagicMock(spec=tinker.SamplingClient)\n        with patch(\n            \"tinker_cookbook.rl.rollouts.do_group_rollout\",\n            side_effect=AllTrajectoriesFailedError(\"all failed\"),\n        ):\n            result = asyncio.run(\n                _do_group_rollout_and_filter_constant_reward_impl(\n                    sampling_client,\n                    builder,\n                    max_tokens=100,\n                    temperature=1.0,\n                    do_remove_constant_reward_groups=False,\n                    strategy=RetryOnFailure(max_retries=3),\n                )\n            )\n        assert result is None\n"
  },
  {
    "path": "tinker_cookbook/rl/rollout_logging.py",
    "content": "\"\"\"Utilities for exporting per-rollout records to JSONL.\"\"\"\n\nimport json\nimport logging\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any\n\nfrom tinker_cookbook.rl.types import TrajectoryGroup\nfrom tinker_cookbook.utils.misc_utils import safezip\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass(frozen=True)\nclass RolloutSummaryExportConfig:\n    \"\"\"Location and metadata for one rollout-summary JSONL export.\"\"\"\n\n    path: Path\n    split: str\n    iteration: int\n    sampling_client_step: int | None = None\n\n\n@dataclass(frozen=True)\nclass RolloutSummaryGroup:\n    \"\"\"One group of trajectories to serialize into rollout-summary JSONL records.\"\"\"\n\n    trajectory_group: TrajectoryGroup\n    tags: list[str]\n    sampling_client_step: int | None = None\n\n\ndef _json_safe(value: Any) -> Any:\n    \"\"\"Convert values to JSON-serializable form.\"\"\"\n    if value is None or isinstance(value, (str, bool, int, float)):\n        return value\n    if isinstance(value, dict):\n        return {str(k): _json_safe(v) for k, v in value.items()}\n    if isinstance(value, (list, tuple)):\n        return [_json_safe(v) for v in value]\n    if hasattr(value, \"item\"):\n        try:\n            return value.item()\n        except Exception:\n            logger.debug(\"Failed to convert %r via .item(), falling back to str()\", type(value))\n    return str(value)\n\n\ndef write_rollout_summaries_jsonl(\n    path: str | Path,\n    *,\n    split: str,\n    iteration: int,\n    trajectory_groups_P: Sequence[TrajectoryGroup],\n    taglist_P: Sequence[list[str]],\n    sampling_client_steps_P: Sequence[int | None] | None = None,\n) -> None:\n    \"\"\"\n    Write one JSON record per rollout trajectory.\n\n    This is intentionally disaggregated: no aggregate or summary statistics.\n    \"\"\"\n    path_obj = Path(path)\n    path_obj.parent.mkdir(parents=True, exist_ok=True)\n\n    with open(path_obj, \"w\") as f:\n        for group_idx, (trajectory_group, tags) in enumerate(\n            safezip(trajectory_groups_P, taglist_P)\n        ):\n            total_rewards_G = trajectory_group.get_total_rewards()\n            sampling_step = (\n                sampling_client_steps_P[group_idx] if sampling_client_steps_P is not None else None\n            )\n\n            for traj_idx, trajectory in enumerate(trajectory_group.trajectories_G):\n                steps = []\n                for step_idx, transition in enumerate(trajectory.transitions):\n                    steps.append(\n                        {\n                            \"step_idx\": step_idx,\n                            \"ob_len\": transition.ob.length,\n                            \"ac_len\": len(transition.ac.tokens),\n                            \"reward\": transition.reward,\n                            \"episode_done\": transition.episode_done,\n                            \"metrics\": transition.metrics,\n                            \"logs\": transition.logs,\n                        }\n                    )\n\n                record = {\n                    \"schema_version\": 1,\n                    \"split\": split,\n                    \"iteration\": iteration,\n                    \"group_idx\": group_idx,\n                    \"traj_idx\": traj_idx,\n                    \"tags\": list(tags),\n                    \"sampling_client_step\": sampling_step,\n                    \"total_reward\": total_rewards_G[traj_idx],\n 
                   \"final_reward\": trajectory_group.final_rewards_G[traj_idx],\n                    \"trajectory_metrics\": trajectory_group.metrics_G[traj_idx],\n                    \"steps\": steps,\n                    \"final_ob_len\": trajectory.final_ob.length,\n                }\n                f.write(json.dumps(_json_safe(record)) + \"\\n\")\n\n\ndef rollout_summaries_jsonl_path(log_path: str, file_prefix: str) -> Path:\n    \"\"\"Build the rollout-summary JSONL path for a train/eval file prefix.\"\"\"\n    return Path(log_path) / f\"{file_prefix}_rollout_summaries.jsonl\"\n\n\ndef write_rollout_summaries_jsonl_from_groups(\n    path: Path,\n    *,\n    split: str,\n    iteration: int,\n    groups_P: Sequence[RolloutSummaryGroup],\n) -> None:\n    \"\"\"Serialize rollout summaries from grouped records with tags and sampler step metadata.\"\"\"\n    write_rollout_summaries_jsonl(\n        path,\n        split=split,\n        iteration=iteration,\n        trajectory_groups_P=[group.trajectory_group for group in groups_P],\n        taglist_P=[group.tags for group in groups_P],\n        sampling_client_steps_P=[group.sampling_client_step for group in groups_P],\n    )\n"
  },
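Because `write_rollout_summaries_jsonl` emits one self-contained JSON record per trajectory, downstream analysis is just line-by-line parsing. A small stdlib-only sketch of reading the file back; the path is an assumed location built from `rollout_summaries_jsonl_path`, and the field names match the record constructed above.

```python
# Sketch: aggregating per-trajectory rewards from a rollout-summary JSONL file.
import json
from collections import defaultdict
from pathlib import Path

path = Path("/tmp/my_run/train_rollout_summaries.jsonl")  # assumed log_path / file_prefix

reward_by_iteration: dict[int, list[float]] = defaultdict(list)
with open(path) as f:
    for line in f:
        record = json.loads(line)
        reward_by_iteration[record["iteration"]].append(record["total_reward"])

for iteration, rewards in sorted(reward_by_iteration.items()):
    print(iteration, sum(rewards) / len(rewards), len(rewards))
```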
  {
    "path": "tinker_cookbook/rl/rollout_logging_test.py",
    "content": "import json\nfrom pathlib import Path\nfrom typing import cast\n\nimport numpy as np\nimport tinker\n\nfrom tinker_cookbook.completers import TokensWithLogprobs\nfrom tinker_cookbook.rl.rollout_logging import write_rollout_summaries_jsonl\nfrom tinker_cookbook.rl.types import Logs, Metrics, Trajectory, TrajectoryGroup, Transition\n\n\ndef test_write_rollout_summaries_jsonl_handles_numpy_scalars(tmp_path: Path):\n    transition = Transition(\n        ob=tinker.ModelInput.from_ints([101, 102, 103, 104, 105]),\n        ac=TokensWithLogprobs(tokens=[1, 2, 3], maybe_logprobs=[-0.1, -0.2, -0.3]),\n        reward=cast(float, np.float32(0.25)),\n        episode_done=True,\n        metrics=cast(Metrics, {\"score\": np.float32(1.5)}),\n        logs=cast(Logs, {\"rank\": np.int64(2)}),\n    )\n    trajectory = Trajectory(transitions=[transition], final_ob=tinker.ModelInput.from_ints([1] * 8))\n    trajectory_group = TrajectoryGroup(\n        trajectories_G=[trajectory],\n        final_rewards_G=[cast(float, np.float32(0.75))],\n        metrics_G=[cast(Metrics, {\"traj_metric\": np.float32(3.0)})],\n    )\n    output_path = tmp_path / \"rollouts.jsonl\"\n\n    write_rollout_summaries_jsonl(\n        output_path,\n        split=\"train\",\n        iteration=1,\n        trajectory_groups_P=[trajectory_group],\n        taglist_P=[[\"unit-test\"]],\n        sampling_client_steps_P=[7],\n    )\n\n    record = json.loads(output_path.read_text().strip())\n    assert record[\"iteration\"] == 1\n    assert record[\"sampling_client_step\"] == 7\n    assert record[\"total_reward\"] == 1.0\n    assert record[\"final_reward\"] == 0.75\n    assert record[\"steps\"][0][\"reward\"] == 0.25\n    assert record[\"steps\"][0][\"metrics\"][\"score\"] == 1.5\n    assert record[\"steps\"][0][\"logs\"][\"rank\"] == 2\n"
  },
  {
    "path": "tinker_cookbook/rl/rollout_strategy.py",
    "content": "\"\"\"Pluggable strategies for collecting trajectories within a rollout group.\n\nA :class:`RolloutStrategy` decides *how* to run N single-rollout coroutines\nin parallel — whether to fail fast, retry on failure, etc.  It owns the full\ntrajectory collection lifecycle including env creation (via\n``EnvGroupBuilder.make_envs()``), so strategies like retry can create fresh\nenvs as needed.\n\nGroup reward computation and logging remain in\n:func:`~tinker_cookbook.rl.rollouts.do_group_rollout`.\n\nImplementations must be pickleable (frozen dataclasses with primitive fields)\nbecause they are bundled into ``_RolloutTask`` for cross-process dispatch.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport logging\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\n\nfrom tinker_cookbook.completers import TokenCompleter\nfrom tinker_cookbook.exceptions import ConfigurationError\nfrom tinker_cookbook.rl.types import Env, EnvGroupBuilder, RolloutError, Trajectory\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass(frozen=True)\nclass RolloutResult:\n    \"\"\"Output of a :class:`RolloutStrategy`.\"\"\"\n\n    trajectories: list[Trajectory]\n    envs: Sequence[Env]\n    errors: list[RolloutError]\n\n\nclass RolloutStrategy(ABC):\n    \"\"\"Controls how trajectories are collected from a group of environments.\n\n    Subclasses implement :meth:`execute` which receives the\n    :class:`EnvGroupBuilder` and a policy, creates envs, runs rollouts,\n    and returns the surviving trajectories plus any error info.\n\n    Implementations must be pickleable — use ``@dataclass(frozen=True)``\n    with only primitive fields.\n    \"\"\"\n\n    @property\n    def catches_group_errors(self) -> bool:\n        \"\"\"If True, group-level errors (``make_envs``, ``compute_group_rewards``)\n        are caught and the group is skipped.  If False, they propagate.\"\"\"\n        return False\n\n    @abstractmethod\n    async def execute(\n        self,\n        env_group_builder: EnvGroupBuilder,\n        policy: TokenCompleter,\n    ) -> RolloutResult:\n        \"\"\"Create envs, run rollouts, and return results.\n\n        May raise on unrecoverable errors (e.g. 
retry budget exhausted).\n        The caller (:func:`do_group_rollout`) handles group-level error\n        recovery based on :attr:`catches_group_errors`.\n        \"\"\"\n        ...\n\n\n@dataclass(frozen=True)\nclass FailFast(RolloutStrategy):\n    \"\"\"Default strategy: any trajectory error crashes the group.\n\n    Produces identical behaviour to the original ``asyncio.gather(...)``\n    path — no error tolerance, no overhead.\n    \"\"\"\n\n    async def execute(\n        self,\n        env_group_builder: EnvGroupBuilder,\n        policy: TokenCompleter,\n    ) -> RolloutResult:\n        from tinker_cookbook.rl.rollouts import do_single_rollout\n\n        envs = await env_group_builder.make_envs()\n        trajectories: list[Trajectory] = list(\n            await asyncio.gather(*[do_single_rollout(policy, env) for env in envs])\n        )\n        return RolloutResult(trajectories=trajectories, envs=envs, errors=[])\n\n\n@dataclass(frozen=True)\nclass RetryOnFailure(RolloutStrategy):\n    \"\"\"Retry failed trajectories with fresh environments.\n\n    When a trajectory fails (container crash, sandbox flake, transient error),\n    a fresh env is created via ``make_envs()`` and the rollout is retried.\n    This continues until either all trajectories succeed or the retry budget\n    is exhausted.\n\n    If the retry budget is exhausted and a failure still occurs, the remaining\n    in-flight tasks are cancelled and the exception is re-raised. This avoids\n    partial-group bias from training on an incomplete set of trajectories.\n\n    Uses ``asyncio.wait(FIRST_COMPLETED)`` so retries start as soon as a\n    failure is detected, without waiting for other in-flight rollouts.\n\n    Args:\n        max_retries: Total retry budget across all trajectories in the group.\n            For example, with ``max_retries=3`` and a group of 8 envs, up to\n            3 individual trajectory failures will be retried.\n    \"\"\"\n\n    max_retries: int = 3\n\n    @property\n    def catches_group_errors(self) -> bool:\n        return True\n\n    async def execute(\n        self,\n        env_group_builder: EnvGroupBuilder,\n        policy: TokenCompleter,\n    ) -> RolloutResult:\n        from tinker_cookbook.rl.rollouts import do_single_rollout\n\n        envs = await env_group_builder.make_envs()\n\n        # Map task -> env for tracking\n        task_to_env: dict[asyncio.Task[Trajectory], Env] = {}\n        for env in envs:\n            task = asyncio.create_task(do_single_rollout(policy, env))\n            task_to_env[task] = env\n\n        trajectories: list[Trajectory] = []\n        surviving_envs: list[Env] = []\n        errors: list[RolloutError] = []\n        retries_remaining = self.max_retries\n        pending: set[asyncio.Task[Trajectory]] = set(task_to_env.keys())\n\n        while pending:\n            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)\n            for task in done:\n                try:\n                    traj = task.result()\n                    trajectories.append(traj)\n                    surviving_envs.append(task_to_env[task])\n                except (asyncio.CancelledError, KeyboardInterrupt):\n                    # Never swallow cancellation — cancel remaining and propagate\n                    for t in pending:\n                        t.cancel()\n                    await asyncio.gather(*pending, return_exceptions=True)\n                    raise\n                except Exception as exc:\n                    logger.warning(\n    
                    \"Trajectory failed (%s): %s (retries_remaining=%d)\",\n                        type(exc).__name__,\n                        exc,\n                        retries_remaining,\n                    )\n                    errors.append(\n                        RolloutError(\n                            error_type=type(exc).__name__,\n                            error_message=str(exc),\n                        )\n                    )\n                    if retries_remaining > 0:\n                        retries_remaining -= 1\n                        # Create a fresh env for retry.\n                        # Note: make_envs() creates a full group but we only need one.\n                        # The extras are cheap Python objects for most envs; for sandbox\n                        # envs the unused containers get GC'd.\n                        new_envs = await env_group_builder.make_envs()\n                        new_env = new_envs[0]\n                        new_task = asyncio.create_task(do_single_rollout(policy, new_env))\n                        task_to_env[new_task] = new_env\n                        pending.add(new_task)\n                    else:\n                        # Budget exhausted — cancel remaining and re-raise.\n                        # This avoids partial-group bias from training on an\n                        # incomplete group of trajectories.\n                        logger.error(\n                            \"Retry budget exhausted (%d retries), cancelling remaining tasks\",\n                            self.max_retries,\n                        )\n                        for t in pending:\n                            t.cancel()\n                        await asyncio.gather(*pending, return_exceptions=True)\n                        raise exc\n\n        return RolloutResult(\n            trajectories=trajectories,\n            envs=surviving_envs,\n            errors=errors,\n        )\n\n\n# ---------------------------------------------------------------------------\n# Config mapping\n# ---------------------------------------------------------------------------\n\n\ndef rollout_strategy_from_config(\n    rollout_error_tolerance: bool | RolloutStrategy,\n) -> RolloutStrategy:\n    \"\"\"Convert a ``Config.rollout_error_tolerance`` value to a :class:`RolloutStrategy`.\n\n    - ``False`` -> :class:`FailFast` (crash on any error, the default)\n    - ``True``  -> :class:`RetryOnFailure` with default ``max_retries=3``\n    - A :class:`RolloutStrategy` instance -> passed through as-is\n    \"\"\"\n    if isinstance(rollout_error_tolerance, RolloutStrategy):\n        return rollout_error_tolerance\n    if rollout_error_tolerance is False:\n        return FailFast()\n    if rollout_error_tolerance is True:\n        return RetryOnFailure()\n    raise ConfigurationError(f\"Invalid rollout_error_tolerance value: {rollout_error_tolerance!r}\")\n"
  },
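In practice the strategy is usually selected through a config value and mapped by `rollout_strategy_from_config`, then threaded into the rollout entry points. The snippet below only exercises the mapping itself (the behaviour shown matches the module above and its tests):

```python
from tinker_cookbook.rl.rollout_strategy import (
    FailFast,
    RetryOnFailure,
    rollout_strategy_from_config,
)

# Config values map onto strategies:
assert isinstance(rollout_strategy_from_config(False), FailFast)       # crash on any error (default)
assert isinstance(rollout_strategy_from_config(True), RetryOnFailure)  # retry with max_retries=3
custom = RetryOnFailure(max_retries=8)
assert rollout_strategy_from_config(custom) is custom                  # instances pass through as-is

# The resulting strategy is then handed to the rollout entry points, e.g.:
#   await do_group_rollout(env_group_builder, policy, strategy=custom)
```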
  {
    "path": "tinker_cookbook/rl/rollouts.py",
    "content": "import asyncio\nimport logging\nimport numbers\nfrom collections import Counter\nfrom concurrent.futures import Executor\nfrom contextvars import ContextVar\nfrom dataclasses import dataclass, field\nfrom typing import Any\n\nimport tinker\n\nfrom tinker_cookbook.completers import TinkerTokenCompleter, TokenCompleter\nfrom tinker_cookbook.exceptions import AllTrajectoriesFailedError\nfrom tinker_cookbook.rl.rollout_strategy import FailFast, RolloutStrategy\nfrom tinker_cookbook.rl.types import (\n    Env,\n    EnvGroupBuilder,\n    Trajectory,\n    TrajectoryGroup,\n    Transition,\n)\nfrom tinker_cookbook.utils import logtree, trace\nfrom tinker_cookbook.utils.misc_utils import all_same\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass RolloutErrorCounter:\n    \"\"\"Accumulates rollout error counts from :class:`TrajectoryGroup` results.\n\n    Lives in the main event loop only — never crosses thread/process boundaries.\n    Error information reaches the counter via :attr:`TrajectoryGroup.rollout_errors`,\n    which is embedded in the return value (pickleable, safe for any executor).\n    \"\"\"\n\n    _counts: Counter[str] = field(default_factory=Counter)\n    _groups_skipped: int = 0\n\n    def ingest(self, result: TrajectoryGroup | None) -> None:\n        \"\"\"Absorb error info from a single rollout result.\"\"\"\n        if result is None:\n            self._groups_skipped += 1\n            return\n        for err in result.rollout_errors:\n            self._counts[err.error_type] += 1\n\n    def get_metrics(self, prefix: str = \"rollout_errors\") -> dict[str, float]:\n        \"\"\"Return cumulative error metrics (monotonically increasing).\"\"\"\n        out: dict[str, float] = {}\n        if self._counts or self._groups_skipped > 0:\n            for k, v in self._counts.items():\n                out[f\"{prefix}/{k}\"] = float(v)\n            out[f\"{prefix}/total\"] = float(sum(self._counts.values()))\n            out[f\"{prefix}/groups_skipped\"] = float(self._groups_skipped)\n        return out\n\n\ndef _log_transition_logs(logs: dict[str, Any]) -> None:\n    \"\"\"Render transition logs in a readable structure without truncating table cells.\"\"\"\n    if not logs:\n        return\n    with logtree.scope_header(\"Diagnostics\"):\n        for key, value in logs.items():\n            text = str(value)\n            if \"\\n\" in text or len(text) > 120:\n                logtree.details(text, summary=key, pre=True)\n            else:\n                logtree.log_text(f\"{key}: {text}\")\n\n\ndef _log_transition_metrics(metrics: dict[str, Any] | None) -> None:\n    \"\"\"Render transition metrics in a compact, always-visible table.\"\"\"\n    if not metrics:\n        return\n    formatted_metrics = {}\n    for key, value in metrics.items():\n        if isinstance(value, numbers.Real):\n            formatted_metrics[key] = f\"{float(value):.3f}\"\n        else:\n            formatted_metrics[key] = str(value)\n    with logtree.scope_header(\"Step Metrics\"):\n        logtree.table_from_dict(\n            formatted_metrics,\n            caption=\"Metrics emitted by env.step\",\n        )\n\n\ndef _log_single_trajectory_details(traj: Trajectory, final_reward: float) -> None:\n    with logtree.scope_header(\"Episode Details\"):\n        for turn_idx, transition in enumerate(traj.transitions, start=1):\n            with logtree.scope_header(f\"Turn {turn_idx}\"):\n                logtree.table_from_dict(\n                    {\n                        
\"ob_len\": transition.ob.length,\n                        \"ac_len\": len(transition.ac.tokens),\n                        \"step_reward\": f\"{transition.reward:.3f}\",\n                    },\n                    caption=\"Step stats\",\n                )\n                _log_transition_metrics(transition.metrics)\n                _log_transition_logs(transition.logs)\n\n        logtree.table_from_dict(\n            {\n                \"num_turns\": len(traj.transitions),\n                \"final_ob_len\": traj.final_ob.length,\n                \"sum_step_rewards\": f\"{sum(t.reward for t in traj.transitions):.3f}\",\n                \"final_group_reward\": f\"{final_reward:.3f}\",\n                \"total_return\": f\"{sum(t.reward for t in traj.transitions) + final_reward:.3f}\",\n            },\n            caption=\"Episode totals\",\n        )\n\n\nasync def do_single_rollout(policy: TokenCompleter, env: Env) -> Trajectory:\n    \"\"\"Run a single rollout (env episode). Env logging (if any) goes into\n    whatever logtree scope the caller has set up.\"\"\"\n    transitions = []\n    async with trace.scope_span(\"env_initial_observation\"):\n        ob, stop_condition = await env.initial_observation()\n    while True:\n        async with trace.scope_span(\"policy_sample\"):\n            ac_with_logprobs = await policy(ob, stop_condition)\n        async with trace.scope_span(\"env_step\"):\n            step_result = await env.step(ac_with_logprobs.tokens)\n        transition = Transition(\n            ob=ob,\n            ac=ac_with_logprobs,\n            reward=step_result.reward,\n            episode_done=step_result.episode_done,\n            metrics=step_result.metrics,\n            logs=step_result.logs,\n        )\n        transitions.append(transition)\n        ob = step_result.next_observation\n        stop_condition = step_result.next_stop_condition\n        if step_result.episode_done:\n            break\n    return Trajectory(transitions=transitions, final_ob=ob)\n\n\n@logtree.scope_header_decorator(\"Group Rollout\")\nasync def do_group_rollout(\n    env_group_builder: EnvGroupBuilder,\n    policy: TokenCompleter,\n    strategy: RolloutStrategy | None = None,\n) -> TrajectoryGroup:\n    \"\"\"Run rollouts for all environments in a group and compute group rewards.\n\n    Args:\n        strategy: Controls how trajectories are collected (error handling,\n            retries, etc.).  
Defaults to :class:`FailFast` which preserves\n            the original fail-on-any-error behaviour.\n    \"\"\"\n    if strategy is None:\n        strategy = FailFast()\n    try:\n        result = await strategy.execute(env_group_builder, policy)\n\n        async with trace.scope_span(\"compute_group_rewards\"):\n            rewards_and_metrics_G = await env_group_builder.compute_group_rewards(\n                result.trajectories, result.envs\n            )\n        rewards_G, metrics_G = zip(*rewards_and_metrics_G, strict=True)\n\n        with logtree.scope_header(\"Trajectory Details\"):\n            for traj_idx, (traj, final_reward) in enumerate(\n                zip(result.trajectories, rewards_G, strict=True)\n            ):\n                with logtree.scope_header(f\"Trajectory {traj_idx} Episode\"):\n                    _log_single_trajectory_details(traj, final_reward)\n\n        return TrajectoryGroup(\n            result.trajectories, list(rewards_G), list(metrics_G), rollout_errors=result.errors\n        )\n    finally:\n        # cleanup() is not wrapped in try/except; implementations must handle failures\n        # internally and not raise, or exceptions here will mask rollout errors.\n        await env_group_builder.cleanup()\n\n\n# ---------------------------------------------------------------------------\n# Rollout executor — allows offloading group rollouts to processes/Ray/etc.\n# ---------------------------------------------------------------------------\n\n_rollout_executor: ContextVar[Executor | None] = ContextVar(\"rollout_executor\", default=None)\n\n\ndef set_rollout_executor(executor: Executor | None) -> None:\n    \"\"\"Set the executor used for group rollouts.\n\n    When set, ``do_group_rollout_and_filter_constant_reward`` dispatches each\n    rollout via ``loop.run_in_executor(executor, ...)`` instead of running it\n    as an asyncio coroutine in the current process.\n\n    Pass any ``concurrent.futures.Executor`` — ``ProcessPoolExecutor`` works\n    out of the box, or wrap Ray / custom cluster dispatchers as ``Executor``.\n\n    Pass ``None`` to revert to the default in-process async behavior.\n    \"\"\"\n    _rollout_executor.set(executor)\n\n\ndef get_rollout_executor() -> Executor | None:\n    \"\"\"Get the current rollout executor (None = in-process async).\"\"\"\n    return _rollout_executor.get()\n\n\n@dataclass(frozen=True)\nclass _RolloutTask:\n    \"\"\"Pickleable bundle of inputs for cross-process rollout dispatch.\"\"\"\n\n    sampling_client: tinker.SamplingClient\n    env_group_builder: EnvGroupBuilder\n    max_tokens: int\n    temperature: float\n    remove_constant_reward_groups: bool\n    enable_logging: bool\n    strategy: RolloutStrategy = field(default_factory=FailFast)\n\n\ndef _run_rollout_sync(task: _RolloutTask) -> TrajectoryGroup | None:\n    \"\"\"Entry point for executor workers. 
Runs the async rollout in a fresh event loop.\n\n    Called by ``loop.run_in_executor()`` — must be a module-level sync function\n    so it can be pickled for ``ProcessPoolExecutor``.\n    \"\"\"\n    return asyncio.run(\n        _do_group_rollout_and_filter_constant_reward_impl(\n            task.sampling_client,\n            task.env_group_builder,\n            task.max_tokens,\n            task.temperature,\n            task.remove_constant_reward_groups,\n            task.enable_logging,\n            strategy=task.strategy,\n        )\n    )\n\n\n@trace.scope\nasync def do_group_rollout_and_filter_constant_reward(\n    sampling_client: tinker.SamplingClient,\n    env_group_builder: EnvGroupBuilder,\n    max_tokens: int,\n    temperature: float,\n    do_remove_constant_reward_groups: bool,\n    enable_logging: bool = True,\n    strategy: RolloutStrategy | None = None,\n) -> TrajectoryGroup | None:\n    \"\"\"Run a group rollout, optionally dispatching to an external executor.\n\n    When a rollout executor is set (via ``set_rollout_executor``), inputs are\n    bundled into a pickleable ``_RolloutTask`` and dispatched via\n    ``loop.run_in_executor()``. Otherwise, runs as an asyncio coroutine\n    in the current process (zero overhead).\n\n    Args:\n        strategy: Controls how trajectories are collected within the group\n            (error handling, retries, etc.).  Defaults to :class:`FailFast`.\n    \"\"\"\n    if strategy is None:\n        strategy = FailFast()\n\n    executor = get_rollout_executor()\n    if executor is not None:\n        task = _RolloutTask(\n            sampling_client=sampling_client,\n            env_group_builder=env_group_builder,\n            max_tokens=max_tokens,\n            temperature=temperature,\n            remove_constant_reward_groups=do_remove_constant_reward_groups,\n            enable_logging=enable_logging,\n            strategy=strategy,\n        )\n        loop = asyncio.get_running_loop()\n        return await loop.run_in_executor(executor, _run_rollout_sync, task)\n\n    return await _do_group_rollout_and_filter_constant_reward_impl(\n        sampling_client,\n        env_group_builder,\n        max_tokens,\n        temperature,\n        do_remove_constant_reward_groups,\n        enable_logging,\n        strategy=strategy,\n    )\n\n\nasync def _do_group_rollout_and_filter_constant_reward_impl(\n    sampling_client: tinker.SamplingClient,\n    env_group_builder: EnvGroupBuilder,\n    max_tokens: int,\n    temperature: float,\n    do_remove_constant_reward_groups: bool,\n    enable_logging: bool = True,\n    strategy: RolloutStrategy | None = None,\n) -> TrajectoryGroup | None:\n    if strategy is None:\n        strategy = FailFast()\n\n    policy = TinkerTokenCompleter(sampling_client, max_tokens=max_tokens, temperature=temperature)\n\n    try:\n        with logtree.optional_enable_logging(enable_logging):\n            trajectory_group = await do_group_rollout(\n                env_group_builder,\n                policy,\n                strategy=strategy,\n            )\n    except AllTrajectoriesFailedError as e:\n        # All retries exhausted — already logged per-trajectory inside the strategy\n        logger.warning(str(e))\n        return None\n    except Exception as e:\n        if not strategy.catches_group_errors:\n            raise\n        logger.warning(f\"Rollout error ({type(e).__name__}), skipping group: {e}\")\n        return None\n\n    # Remove if all trajectories have the same reward\n    if 
do_remove_constant_reward_groups and all_same(trajectory_group.get_total_rewards()):\n        return None\n    return trajectory_group\n"
  },
  {
    "path": "tinker_cookbook/rl/shutdown_test.py",
    "content": "\"\"\"\nTests for the cascading shutdown mechanism in async RL training.\n\nThese tests validate that when the dataloader exhausts its data, the shutdown\npropagates cleanly through the pipeline without hanging:\n  dataloader -> workers -> training loop -> evaluation loop\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\n\nfrom tinker_cookbook.rl.train import _AsyncCounter, _Shutdown\n\n\nclass TestAsyncCounter:\n    def test_decrement_and_get(self):\n        async def _test():\n            counter = _AsyncCounter(3)\n            assert await counter.decrement_and_get() == 2\n            assert await counter.decrement_and_get() == 1\n            assert await counter.decrement_and_get() == 0\n\n        asyncio.run(_test())\n\n    def test_concurrent_decrements(self):\n        \"\"\"Multiple concurrent decrements should each see a unique value.\"\"\"\n\n        async def _test():\n            counter = _AsyncCounter(100)\n            results = await asyncio.gather(*[counter.decrement_and_get() for _ in range(100)])\n            # Each decrement should produce a unique value from 0 to 99\n            assert sorted(results) == list(range(100))\n\n        asyncio.run(_test())\n\n\nclass TestShutdownCascade:\n    def test_dataloader_enqueues_shutdown_sentinels(self):\n        \"\"\"When the dataloader finishes, it should enqueue one _Shutdown per worker.\"\"\"\n\n        async def _test():\n            num_workers = 4\n            queue: asyncio.Queue[str | _Shutdown] = asyncio.Queue(maxsize=num_workers)\n\n            for _ in range(num_workers):\n                await queue.put(_Shutdown())\n\n            for _ in range(num_workers):\n                item = await queue.get()\n                assert isinstance(item, _Shutdown)\n\n            assert queue.empty()\n\n        asyncio.run(_test())\n\n    def test_last_worker_signals_training_loop(self):\n        \"\"\"The last worker to exit should enqueue a _Shutdown to the training queue.\"\"\"\n\n        async def _test():\n            num_workers = 3\n            counter = _AsyncCounter(num_workers)\n            training_queue: asyncio.Queue[str | _Shutdown] = asyncio.Queue()\n\n            for _ in range(num_workers):\n                num_alive = await counter.decrement_and_get()\n                if num_alive == 0:\n                    training_queue.put_nowait(_Shutdown())\n\n            assert training_queue.qsize() == 1\n            assert isinstance(await training_queue.get(), _Shutdown)\n\n        asyncio.run(_test())\n\n    def test_full_cascade_no_hang(self):\n        \"\"\"\n        Full integration test: wire up all four loops with mock rollouts and verify\n        the entire pipeline shuts down cleanly without hanging.\n        \"\"\"\n\n        async def _test():\n            num_workers = 2\n            num_batches = 2\n            items_per_batch = 2\n\n            env_queue: asyncio.Queue[int | _Shutdown] = asyncio.Queue(maxsize=num_workers)\n            trajectory_queue: asyncio.Queue[int | _Shutdown | None] = asyncio.Queue()\n            dataloader_done = asyncio.Event()\n            eval_should_shutdown = asyncio.Event()\n            worker_counter = _AsyncCounter(num_workers)\n            sampling_updated = asyncio.Event()\n            sampling_updated.set()\n\n            loops_completed: list[str] = []\n\n            async def dataloader_loop():\n                for batch_idx in range(num_batches):\n                    for item_idx in range(items_per_batch):\n                        await 
env_queue.put(batch_idx * items_per_batch + item_idx)\n                dataloader_done.set()\n                for _ in range(num_workers):\n                    await env_queue.put(_Shutdown())\n                loops_completed.append(\"dataloader\")\n\n            async def worker_loop():\n                while True:\n                    item = await env_queue.get()\n                    if isinstance(item, _Shutdown):\n                        break\n                    await asyncio.sleep(0.01)\n                    trajectory_queue.put_nowait(item)\n                num_alive = await worker_counter.decrement_and_get()\n                if num_alive == 0:\n                    trajectory_queue.put_nowait(_Shutdown())\n                loops_completed.append(\"worker\")\n\n            async def training_loop():\n                items_consumed = 0\n                target = num_batches * items_per_batch\n                while items_consumed < target:\n                    item = await trajectory_queue.get()\n                    if isinstance(item, _Shutdown):\n                        break\n                    if item is None:\n                        continue\n                    items_consumed += 1\n                    sampling_updated.set()\n                eval_should_shutdown.set()\n                sampling_updated.set()\n                loops_completed.append(\"training\")\n\n            async def evaluation_loop():\n                while not eval_should_shutdown.is_set():\n                    await sampling_updated.wait()\n                    sampling_updated.clear()\n                loops_completed.append(\"evaluation\")\n\n            await asyncio.wait_for(\n                asyncio.gather(\n                    dataloader_loop(),\n                    *[worker_loop() for _ in range(num_workers)],\n                    training_loop(),\n                    evaluation_loop(),\n                ),\n                timeout=5.0,\n            )\n\n            assert \"dataloader\" in loops_completed\n            assert loops_completed.count(\"worker\") == num_workers\n            assert \"training\" in loops_completed\n            assert \"evaluation\" in loops_completed\n\n        asyncio.run(_test())\n\n    def test_cascade_with_early_shutdown(self):\n        \"\"\"\n        When the dataloader has fewer items than the training loop expects,\n        the _Shutdown sentinel should still propagate and prevent hanging.\n        \"\"\"\n\n        async def _test():\n            num_workers = 2\n            num_dataloader_batches = 1\n            items_per_batch = 2\n            training_loop_target = 10  # Expects more than dataloader provides\n\n            env_queue: asyncio.Queue[int | _Shutdown] = asyncio.Queue(maxsize=num_workers)\n            trajectory_queue: asyncio.Queue[int | _Shutdown | None] = asyncio.Queue()\n            eval_should_shutdown = asyncio.Event()\n            worker_counter = _AsyncCounter(num_workers)\n            sampling_updated = asyncio.Event()\n            sampling_updated.set()\n\n            async def dataloader_loop():\n                for batch_idx in range(num_dataloader_batches):\n                    for item_idx in range(items_per_batch):\n                        await env_queue.put(batch_idx * items_per_batch + item_idx)\n                for _ in range(num_workers):\n                    await env_queue.put(_Shutdown())\n\n            async def worker_loop():\n                while True:\n                    item = await env_queue.get()\n                    if 
isinstance(item, _Shutdown):\n                        break\n                    trajectory_queue.put_nowait(item)\n                num_alive = await worker_counter.decrement_and_get()\n                if num_alive == 0:\n                    trajectory_queue.put_nowait(_Shutdown())\n\n            async def training_loop():\n                i_batch = 0\n                while i_batch < training_loop_target:\n                    item = await trajectory_queue.get()\n                    if isinstance(item, _Shutdown):\n                        break\n                    if item is None:\n                        continue\n                    i_batch += 1\n                eval_should_shutdown.set()\n                sampling_updated.set()\n\n            async def evaluation_loop():\n                while not eval_should_shutdown.is_set():\n                    await sampling_updated.wait()\n                    sampling_updated.clear()\n\n            # Should not hang — shutdown cascade terminates all loops\n            await asyncio.wait_for(\n                asyncio.gather(\n                    dataloader_loop(),\n                    *[worker_loop() for _ in range(num_workers)],\n                    training_loop(),\n                    evaluation_loop(),\n                ),\n                timeout=5.0,\n            )\n\n        asyncio.run(_test())\n\n    def test_requeue_skipped_during_shutdown(self):\n        \"\"\"\n        When the dataloader is done, stale samples should be discarded\n        rather than requeued (to avoid deadlocking on a full bounded queue).\n        \"\"\"\n        dataloader_done = asyncio.Event()\n\n        requeue_attempted = False\n        discard_count = 0\n\n        def filter_stale(is_stale: bool) -> bool:\n            nonlocal requeue_attempted, discard_count\n            if is_stale:\n                if dataloader_done.is_set():\n                    discard_count += 1\n                else:\n                    requeue_attempted = True\n                return False\n            return True\n\n        # Before dataloader is done: stale items should attempt requeue\n        filter_stale(is_stale=True)\n        assert requeue_attempted\n\n        # After dataloader is done: stale items should be discarded\n        requeue_attempted = False\n        dataloader_done.set()\n        filter_stale(is_stale=True)\n        assert not requeue_attempted\n        assert discard_count == 1\n\n    def test_none_items_pass_through_during_shutdown(self):\n        \"\"\"\n        None items (failed rollouts) should be skipped, and _Shutdown should\n        still be received even if preceded by None items.\n        \"\"\"\n\n        async def _test():\n            queue: asyncio.Queue[int | _Shutdown | None] = asyncio.Queue()\n\n            queue.put_nowait(None)\n            queue.put_nowait(None)\n            queue.put_nowait(42)\n            queue.put_nowait(None)\n            queue.put_nowait(_Shutdown())\n\n            received_items = []\n            while True:\n                item = await queue.get()\n                if isinstance(item, _Shutdown):\n                    break\n                if item is None:\n                    continue\n                received_items.append(item)\n\n            assert received_items == [42]\n\n        asyncio.run(_test())\n"
  },
  {
    "path": "tinker_cookbook/rl/train.py",
    "content": "\"\"\"\nImplements RL on general MDPs\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport io\nimport logging\nimport re\nimport time\nfrom collections.abc import Callable, Coroutine, Iterable, Iterator, Sequence\nfrom concurrent.futures import Executor\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, TypeVar\n\nimport chz\nimport numpy as np\nimport tinker\nimport torch\nfrom tinker.types import LossFnType\nfrom tqdm import tqdm\n\nfrom tinker_cookbook import checkpoint_utils, model_info\nfrom tinker_cookbook.display import colorize_example\nfrom tinker_cookbook.eval.evaluators import SamplingClientEvaluator, SamplingClientEvaluatorBuilder\nfrom tinker_cookbook.exceptions import ConfigurationError\nfrom tinker_cookbook.rl.data_processing import (\n    assemble_training_data,\n    compute_advantages,\n    remove_constant_reward_groups,\n)\nfrom tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics\nfrom tinker_cookbook.rl.metrics import (\n    compute_kl_sample_train,\n    compute_post_kl,\n    compute_sampling_client_metrics,\n    incorporate_kl_penalty,\n)\nfrom tinker_cookbook.rl.rollout_logging import (\n    RolloutSummaryExportConfig,\n    RolloutSummaryGroup,\n    rollout_summaries_jsonl_path,\n    write_rollout_summaries_jsonl_from_groups,\n)\nfrom tinker_cookbook.rl.rollout_strategy import (\n    RolloutStrategy,\n    rollout_strategy_from_config,\n)\nfrom tinker_cookbook.rl.rollouts import (\n    RolloutErrorCounter,\n    do_group_rollout,  # noqa: F401 — re-exported for verifiers monkey-patching\n    do_group_rollout_and_filter_constant_reward,\n    set_rollout_executor,\n)\nfrom tinker_cookbook.rl.types import (\n    EnvGroupBuilder,\n    RLDataset,\n    RLDatasetBuilder,\n    TrajectoryGroup,\n)\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\nfrom tinker_cookbook.utils import logtree, ml_log, trace\nfrom tinker_cookbook.utils.misc_utils import safezip, split_list, timed\n\nlogger = logging.getLogger(__name__)\n\nT = TypeVar(\"T\")\n\n\n@chz.chz\nclass KLReferenceConfig:\n    \"\"\"Configuration for the KL penalty reference model.\n\n    If not specified in Config, the training model's base model is used.\n    \"\"\"\n\n    base_model: str\n    load_checkpoint_path: str | None = None\n\n\nasync def gather_with_progress(\n    coroutines: Iterable[Coroutine[Any, Any, T]],\n    desc: str,\n) -> list[T]:\n    \"\"\"\n    Run coroutines concurrently with a progress bar that updates as each completes.\n\n    This preserves the order of results (like asyncio.gather) while providing\n    real-time progress feedback as individual coroutines complete.\n    \"\"\"\n    coroutine_list = list(coroutines)\n    pbar = tqdm(total=len(coroutine_list), desc=desc)\n\n    async def track(coro: Coroutine[Any, Any, T]) -> T:\n        result = await coro\n        pbar.update(1)\n        return result\n\n    try:\n        results = await asyncio.gather(*[track(coro) for coro in coroutine_list])\n    finally:\n        pbar.close()\n\n    return results\n\n\ndef _get_evaluator_name(evaluator: SamplingClientEvaluator) -> str:\n    return (\n        evaluator.name\n        if isinstance(evaluator, RLTestSetEvaluator) and evaluator.name is not None\n        else \"\"\n    )\n\n\ndef _sanitize_filename_component(text: str) -> str:\n    \"\"\"Make a safe filename component.\"\"\"\n    sanitized = re.sub(r\"[^A-Za-z0-9_.-]+\", \"_\", text)\n    return 
sanitized.strip(\"._\") or \"unnamed\"\n\n\ndef _maybe_export_rollout_summary_jsonl(\n    *,\n    cfg: Config,\n    file_prefix: str,\n    split: str,\n    iteration: int,\n    groups_P: Sequence[RolloutSummaryGroup],\n) -> None:\n    \"\"\"\n    Write per-trajectory rollout summaries for one train/eval pass when enabled.\n\n    This is a thin policy gate around rollout_logging utilities:\n    - path naming (`<file_prefix>_rollout_summaries.jsonl`)\n    - on/off switch (`cfg.rollout_json_export`)\n    \"\"\"\n    if not cfg.rollout_json_export:\n        return\n    write_rollout_summaries_jsonl_from_groups(\n        rollout_summaries_jsonl_path(cfg.log_path, file_prefix),\n        split=split,\n        iteration=iteration,\n        groups_P=groups_P,\n    )\n\n\n_LOGTREE_EXPLANATION = (\n    \"This HTML log was generated by logtree during RL training. \"\n    \"It shows rollouts and rewards for a subset of trajectory groups in this iteration. \"\n    \"To customize what gets logged, modify the logtree calls in your Env implementation \"\n    \"(see examples in tinker_cookbook/recipes/).\"\n)\n\n\n@contextmanager\ndef _get_logtree_scope(\n    log_path: str | None, num_groups_to_log: int, f_name: str, scope_name: str\n) -> Iterator[None]:\n    \"\"\"\n    Creates a context manager; everything logged inside it is grouped under the section `scope_name`.\n    It creates files at the paths log_path/f_name.html and log_path/f_name_logtree.json.\n    If num_groups_to_log is 0, logging is disabled (note that this function only sets up the scope; it does not do the logging itself).\n    \"\"\"\n    if log_path is None or num_groups_to_log <= 0:\n        yield\n        return\n\n    logtree_path = str(Path(log_path) / f\"{f_name}.html\")\n    logtree_json_path = str(Path(log_path) / f\"{f_name}_logtree.json\")\n    trace = None\n    try:\n        with logtree.init_trace(scope_name, path=logtree_path) as trace:\n            logtree.log_text(_LOGTREE_EXPLANATION)\n            yield\n    finally:\n        if trace is not None:\n            logtree.write_trace_json(trace, logtree_json_path)\n\n\n@trace.scope\ndef _select_representative_inds(scores: list[float], num_inds: int) -> list[int]:\n    assert num_inds <= len(scores)\n    sorted_inds = np.argsort(scores)\n    uniform_inds = np.linspace(0, len(sorted_inds) - 1, num_inds).astype(int)\n    return [int(sorted_inds[i]) for i in uniform_inds]\n\n\n@trace.scope\ndef print_group(traj_group: TrajectoryGroup, tokenizer: Tokenizer):\n    \"\"\"\n    Print a subset of the trajectory group to the console.\n    \"\"\"\n    # Cut down the number of trajectories to print\n    max_trajs_to_print = 4\n    if len(traj_group.trajectories_G) > max_trajs_to_print:\n        inds = _select_representative_inds(traj_group.get_total_rewards(), max_trajs_to_print)\n        traj_group = TrajectoryGroup(\n            trajectories_G=[traj_group.trajectories_G[i] for i in inds],\n            final_rewards_G=[traj_group.final_rewards_G[i] for i in inds],\n            metrics_G=[traj_group.metrics_G[i] for i in inds],\n        )\n\n    rewards = traj_group.get_total_rewards()\n    advantages_G = compute_advantages([traj_group])\n    data_D, metadata_D = assemble_training_data([traj_group], advantages_G)\n\n    buf = io.StringIO()\n\n    @trace.scope\n    def bprint(s: str):\n        print(s, file=buf)\n\n    bprint(\"\\n====== Trajectory Group ======\")\n    last_metadata = None\n    for datum, metadata in safezip(data_D, metadata_D):\n        idx = 
metadata[\"traj_idx\"]\n        if metadata != last_metadata:\n            bprint(f\"****** trajectory idx={idx}, reward={rewards[idx]:.3g} ******\")\n            # Print trajectory-level metrics\n            if traj_group.metrics_G[idx]:\n                bprint(\"Trajectory metrics:\")\n                for key, value in traj_group.metrics_G[idx].items():\n                    bprint(f\"  {key}: {value}\")\n            # Print per-transition metrics\n            transition_metrics = [\n                transition.metrics\n                for transition in traj_group.trajectories_G[idx].transitions\n                if transition.metrics\n            ]\n            if transition_metrics:\n                bprint(\"Per-step metrics:\")\n                for i, metrics in enumerate(transition_metrics):\n                    bprint(f\"  Step {i}:\")\n                    for key, value in metrics.items():\n                        bprint(f\"    {key}: {value}\")\n        bprint(\"---- datum ----\")\n        bprint(colorize_example(datum, tokenizer, key=\"advantages\"))\n        last_metadata = metadata\n    bprint(\"====== End Trajectory Group ======\")\n    logger.info(buf.getvalue().rstrip())\n\n\ndef _remove_mask(datum: tinker.Datum) -> tinker.Datum:\n    return tinker.Datum(\n        model_input=datum.model_input,\n        loss_fn_inputs={k: v for k, v in datum.loss_fn_inputs.items() if k != \"mask\"},\n    )\n\n\ndef _training_logprobs_from_fwd_bwd(\n    fwd_bwd_result: tinker.ForwardBackwardOutput,\n) -> list[torch.Tensor]:\n    return [output[\"logprobs\"].to_torch() for output in fwd_bwd_result.loss_fn_outputs]\n\n\n@trace.scope\nasync def train_step(\n    data_D: list[tinker.Datum],\n    training_client: tinker.TrainingClient,\n    learning_rate: float,\n    num_substeps: int,\n    loss_fn: LossFnType,\n    loss_fn_config: dict[str, Any] | None = None,\n    metrics: dict[str, Any] | None = None,\n) -> list[torch.Tensor]:\n    \"\"\"Train the model on collected trajectories.\n\n    Pipelines forward_backward and optim_step so they land on the same clock cycle.\n    \"\"\"\n    batches = split_list(data_D, min(num_substeps, len(data_D)))\n    if not batches:\n        return []\n\n    adam_params = tinker.AdamParams(learning_rate=learning_rate, beta1=0.9, beta2=0.95, eps=1e-8)\n    training_logprobs_D: list[torch.Tensor] = []\n    optim_result: tinker.OptimStepResponse | None = None\n\n    # Enqueue first batch\n    fwd_bwd_future = await training_client.forward_backward_async(\n        [_remove_mask(d) for d in batches[0]], loss_fn=loss_fn, loss_fn_config=loss_fn_config\n    )\n    optim_future = await training_client.optim_step_async(adam_params)\n\n    for i in range(len(batches)):\n        # Enqueue next batch before consuming current results (to stay on same clock cycle)\n        if i + 1 < len(batches):\n            next_fwd_bwd_future = await training_client.forward_backward_async(\n                [_remove_mask(d) for d in batches[i + 1]],\n                loss_fn=loss_fn,\n                loss_fn_config=loss_fn_config,\n            )\n            next_optim_future = await training_client.optim_step_async(adam_params)\n        else:\n            next_fwd_bwd_future = None\n            next_optim_future = None\n        # Consume current results\n        fwd_bwd_result = await fwd_bwd_future.result_async()\n        training_logprobs_D.extend(_training_logprobs_from_fwd_bwd(fwd_bwd_result))\n        optim_result = await optim_future.result_async()\n        # Move to next iteration\n        
if next_fwd_bwd_future is not None and next_optim_future is not None:\n            fwd_bwd_future = next_fwd_bwd_future\n            optim_future = next_optim_future\n\n    if metrics is not None and optim_result is not None and optim_result.metrics:\n        metrics.update(optim_result.metrics)\n\n    return training_logprobs_D\n\n\n@chz.chz\nclass StreamMinibatchConfig:\n    \"\"\"\n    Configuration for training with minibatch streaming.\n    Once we have accumulated enough trajectories for a minibatch, we will\n    immediately train on them, instead of waiting for the full batch of\n    trajectories to be ready.\n    \"\"\"\n\n    # Total number of trajectory groups across all minibatches and substeps\n    groups_per_batch: int\n    # For each substep, we will divide up the number of trajectory groups\n    # into this many minibatches.\n    # We will do num_minibatches forward_backward() passes and one optim_step()\n    # per substep.\n    num_minibatches: int\n\n\n@chz.chz\nclass AsyncConfig:\n    \"\"\"Configuration for async RL training\"\"\"\n\n    # If samples were generated by a sampler more than this many steps old,\n    # we will skip training on them.\n    max_steps_off_policy: int\n    # We will ensure all batches have at least this many groups, even\n    # as we discard stale samples\n    groups_per_batch: int\n\n\n@chz.chz\nclass Config:\n    \"\"\"Configuration for RL training.\"\"\"\n\n    # -------------------------------------------------------------------------\n    # Core parameters (recommended to set for nearly all runs)\n    # -------------------------------------------------------------------------\n    # Base learning rate used by Adam.\n    learning_rate: float\n    # Builds the RL dataset; also determines number of groups per batch.\n    dataset_builder: RLDatasetBuilder\n    # Model name (base weights) to train.\n    model_name: str\n    # Maximum number of generated tokens per rollout trajectory.\n    max_tokens: int\n    # Directory for checkpoints, logs, and traces.\n    log_path: str = chz.field(munger=lambda _, s: str(Path(s).expanduser()))\n    # Evaluation cadence in training iterations (0 = disabled).\n    eval_every: int = 20\n    # Checkpoint cadence in training iterations (0 = disabled).\n    save_every: int = 20\n    # Optional evaluators run during training.\n    evaluator_builders: list[SamplingClientEvaluatorBuilder] = chz.field(default_factory=list)\n    # Start training from weights at this checkpoint (fresh optimizer state).\n    load_checkpoint_path: str | None = None\n    # Renderer used by the training dataset/environment.\n    renderer_name: str | None = None\n    # Optional W&B project and run name.\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    # -------------------------------------------------------------------------\n    # KL penalty configuration (advanced)\n    # -------------------------------------------------------------------------\n    # KL penalty coefficient against reference policy (0 = disabled).\n    kl_penalty_coef: float = 0.0\n    # Optional position discount for KL penalty terms.\n    kl_discount_factor: float = 0.0\n    # Required when kl_penalty_coef > 0.\n    kl_reference_config: KLReferenceConfig | None = None\n\n    # -------------------------------------------------------------------------\n    # Loss and optimizer behavior (advanced)\n    # -------------------------------------------------------------------------\n    # Loss function and configuration.\n    # See 
https://tinker-docs.thinkingmachines.ai/losses\n    loss_fn: LossFnType = \"importance_sampling\"\n    loss_fn_config: dict[str, Any] | None = None\n\n    # Number of optimizer steps per training iteration.\n    # Useful for very large batch sizes.\n    num_substeps: int = 1\n    # LoRA rank for the training adapter.\n    lora_rank: int = 32\n\n    # -------------------------------------------------------------------------\n    # Sampling and diagnostics (advanced)\n    # -------------------------------------------------------------------------\n    # Changing sampling temperature is not generally recommended; T=1 is near-optimal\n    # for most post-trained models, and non-1 temperatures currently do not play\n    # well with KL penalty.\n    temperature: float = 1.0\n    # Compute extra post-update KL metrics (adds overhead).\n    compute_post_kl: bool = False\n    # Remove groups where all trajectories have identical reward.\n    remove_constant_reward_groups: bool = False\n    # Tolerance for errors during rollouts (container crashes, sandbox flakes, etc.).\n    # False (default): crash on any error (FailFast).\n    # True: retry failed trajectories with default budget (RetryOnFailure(max_retries=3)).\n    # RolloutStrategy instance: custom strategy (e.g. RetryOnFailure(max_retries=5)).\n    rollout_error_tolerance: bool | RolloutStrategy = False\n    # Emit async trace events for debugging/profiling.\n    enable_trace: bool = False\n    # Save a Gantt chart HTML every N iterations (0 = disabled). Requires plotly.\n    span_chart_every: int = 0\n\n    # -------------------------------------------------------------------------\n    # Execution mode knobs (advanced)\n    # -------------------------------------------------------------------------\n    # Enable async/off-policy training mode when set.\n    async_config: AsyncConfig | None = None\n    # Enable sync training with streaming minibatches when set.\n    stream_minibatch_config: StreamMinibatchConfig | None = None\n    # Optional service base URL override (primarily internal/dev use).\n    base_url: str | None = None\n\n    # -------------------------------------------------------------------------\n    # Checkpoint retention and logging detail (advanced)\n    # -------------------------------------------------------------------------\n    # Periodic checkpoints use this TTL; the final checkpoint is kept indefinitely.\n    # None disables expiry entirely.\n    ttl_seconds: int | None = 604800  # 7 days\n    num_groups_to_log: int = 4  # Number of groups to log per iteration (0 = disable logging)\n    rollout_json_export: bool = True\n\n    # Maximum number of training iterations. 
If None, train on the full dataset.\n    max_steps: int | None = None\n\n\n@trace.scope\nasync def run_single_evaluation(\n    evaluator: SamplingClientEvaluator,\n    cfg: Config,\n    i_batch: int,\n    sampling_client: tinker.SamplingClient,\n    evaluator_label: str,\n) -> dict[str, Any]:\n    ev_name = _get_evaluator_name(evaluator)\n    eval_file_prefix = f\"eval_{evaluator_label}_iteration_{i_batch:06d}\"\n    with _get_logtree_scope(\n        log_path=cfg.log_path,\n        num_groups_to_log=cfg.num_groups_to_log,\n        f_name=eval_file_prefix,\n        scope_name=f\"Running evaluation {ev_name} {i_batch}\",\n    ):\n        if isinstance(evaluator, RLTestSetEvaluator):\n            rollout_summary_export = (\n                RolloutSummaryExportConfig(\n                    path=rollout_summaries_jsonl_path(cfg.log_path, eval_file_prefix),\n                    split=f\"eval/{evaluator_label}\",\n                    iteration=i_batch,\n                    sampling_client_step=i_batch,\n                )\n                if cfg.rollout_json_export\n                else None\n            )\n            eval_metrics = await evaluator(\n                sampling_client,\n                rollout_summary_export=rollout_summary_export,\n            )\n        else:\n            eval_metrics = await evaluator(sampling_client)\n        return eval_metrics\n\n\n@trace.scope\nasync def run_evaluations_parallel(\n    evaluators: list[SamplingClientEvaluator],\n    sampling_client: tinker.SamplingClient,\n    cfg: Config,\n    i_batch: int,\n) -> dict[str, Any]:\n    \"\"\"Run all evaluators in parallel and return aggregated metrics.\"\"\"\n\n    # Create tasks for all evaluators with names for better traceability\n    tasks = []\n    for i, evaluator in enumerate(evaluators):\n        ev_name = _get_evaluator_name(evaluator)\n        evaluator_label = _sanitize_filename_component(ev_name or str(i))\n        task = asyncio.create_task(\n            run_single_evaluation(evaluator, cfg, i_batch, sampling_client, evaluator_label),\n            name=f\"eval_{evaluator_label}_iteration_{i_batch:06d}\",\n        )\n        tasks.append(task)\n\n    # Wait for all to complete\n    results = await asyncio.gather(*tasks)\n\n    # Merge all metrics\n    metrics = {}\n    for result in results:\n        metrics.update(result)\n\n    return metrics\n\n\n@trace.scope\nasync def do_sync_training_with_stream_minibatch(\n    start_batch: int,\n    end_batch: int,\n    num_batches: int,\n    cfg: Config,\n    training_client: tinker.TrainingClient,\n    kl_reference_client: tinker.SamplingClient | None,\n    evaluators: list[SamplingClientEvaluator],\n    dataset: RLDataset,\n    ml_logger: ml_log.Logger,\n    tokenizer: Tokenizer,\n    error_counter: RolloutErrorCounter | None = None,\n    strategy: RolloutStrategy | None = None,\n):\n    \"\"\"\n    Implements fully synchronous on-policy training with minibatch streaming.\n    Once we have accumulated enough trajectories for a minibatch, we will\n    immediately train on them, instead of waiting for the full batch of\n    trajectories to be ready. 
This allows us to overlap sampling and training.\n    \"\"\"\n    # Initial sampling client\n    sampling_client, _ = await save_checkpoint_and_get_sampling_client(\n        training_client, start_batch, cfg.log_path, cfg.save_every, start_batch, cfg.ttl_seconds\n    )\n\n    for i_batch in range(start_batch, end_batch):\n        metrics = {\n            \"progress/batch\": i_batch,\n            \"optim/lr\": cfg.learning_rate,\n            \"progress/done_frac\": (i_batch + 1) / num_batches,\n        }\n        t_start = time.time()\n\n        # Run evaluations\n        if (cfg.eval_every > 0 and i_batch % cfg.eval_every == 0) or i_batch == end_batch - 1:\n            with timed(\"run_evals\", metrics):\n                eval_metrics = await run_evaluations_parallel(\n                    evaluators, sampling_client, cfg, i_batch\n                )\n                metrics.update(eval_metrics)\n\n        with _get_logtree_scope(\n            cfg.log_path,\n            cfg.num_groups_to_log,\n            f\"train_iteration_{i_batch:06d}\",\n            f\"RL Iteration {i_batch}\",\n        ):\n            # Samplers will produce trajectory groups asynchronously,\n            # and the trainer will consume them as soon as they are ready\n            trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()\n            env_group_builders_P = dataset.get_batch(i_batch)\n\n            @trace.scope\n            async def trajectory_group_worker_task(\n                builder: EnvGroupBuilder, enable_logging: bool\n            ) -> None:\n                metrics = {}\n                t_start = time.time()\n                trajectory_group = await do_group_rollout_and_filter_constant_reward(\n                    sampling_client,\n                    builder,\n                    max_tokens=cfg.max_tokens,\n                    temperature=cfg.temperature,\n                    do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,\n                    enable_logging=enable_logging,\n                    strategy=strategy,\n                )\n                metrics[\"time/trajectory_group_worker_loop/total\"] = time.time() - t_start\n                # Ingest error info (safe: same event loop thread)\n                if error_counter is not None:\n                    error_counter.ingest(trajectory_group)\n                if trajectory_group is not None:\n                    trajectory_groups_queue.put_nowait(\n                        WrappedTrajectoryGroup(\n                            trajectory_group=trajectory_group,\n                            env_group_builder=builder,\n                            sampling_client_step=i_batch,\n                            metrics=metrics,\n                        )\n                    )\n                else:\n                    trajectory_groups_queue.put_nowait(None)\n\n            # Sample all trajectories asynchronously. 
If we have multiple minibatches,\n            # then sampling can overlap with training.\n            for i, builder in enumerate(env_group_builders_P):\n                asyncio.create_task(\n                    trajectory_group_worker_task(builder, enable_logging=i < cfg.num_groups_to_log),\n                    name=f\"trajectory_group_worker_task_{i}\",\n                )\n\n            # Run multiple optimizer substeps per training iteration\n            streaming_result = await do_train_step_streaming_and_get_sampling_client(\n                cfg,\n                i_batch,\n                trajectory_groups_queue,\n                training_client,\n                kl_reference_client,\n                tokenizer,\n            )\n            # _Shutdown cannot appear in the sync path's local queue\n            assert streaming_result is not None, \"Unexpected shutdown in sync streaming path\"\n            (\n                sampling_client,\n                full_batch_metrics,\n                full_batch_wrapped_trajectory_groups,\n            ) = streaming_result\n\n        _maybe_export_rollout_summary_jsonl(\n            cfg=cfg,\n            file_prefix=f\"train_iteration_{i_batch:06d}\",\n            split=\"train\",\n            iteration=i_batch,\n            groups_P=[\n                RolloutSummaryGroup(\n                    trajectory_group=group.trajectory_group,\n                    tags=group.env_group_builder.logging_tags(),\n                    sampling_client_step=group.sampling_client_step,\n                )\n                for group in full_batch_wrapped_trajectory_groups\n            ],\n        )\n\n        # Log metrics\n        metrics.update(full_batch_metrics)\n        if error_counter is not None:\n            metrics.update(error_counter.get_metrics())\n        metrics[\"time/total\"] = time.time() - t_start\n        ml_logger.log_metrics(metrics, step=i_batch)\n\n\n@chz.chz\nclass WrappedTrajectoryGroup:\n    \"\"\"\n    A wrapper around a trajectory group that includes metadata about how it was generated.\n    Used when we need to overlap sampling and training.\n    \"\"\"\n\n    trajectory_group: TrajectoryGroup\n    # The env group builder that produced the trajectory group.\n    # Pass this along in case the sampler is too stale, and we need to\n    # requeue this group.\n    env_group_builder: EnvGroupBuilder\n    # The step that produced this trajectory group.\n    sampling_client_step: int\n    metrics: dict[str, Any] = chz.field(default_factory=dict)\n\n\n@dataclass\nclass _Shutdown:\n    \"\"\"Sentinel value to signal graceful shutdown through async queues.\n\n    Used in the cascading shutdown protocol for async RL training:\n    dataloader -> workers -> training loop -> evaluation loop.\n    \"\"\"\n\n    pass\n\n\nclass _AsyncCounter:\n    \"\"\"Async-safe counter for tracking the number of alive worker tasks.\"\"\"\n\n    def __init__(self, start: int):\n        self._value = start\n        self._lock = asyncio.Lock()\n\n    async def decrement_and_get(self) -> int:\n        async with self._lock:\n            self._value -= 1\n            return self._value\n\n\n@trace.scope\nasync def do_async_training(\n    start_batch: int,\n    end_batch: int,\n    num_batches: int,\n    cfg: Config,\n    training_client: tinker.TrainingClient,\n    kl_reference_client: tinker.SamplingClient | None,\n    evaluators: list[SamplingClientEvaluator],\n    dataset: RLDataset,\n    ml_logger: ml_log.Logger,\n    tokenizer: Tokenizer,\n    error_counter: 
RolloutErrorCounter | None = None,\n    strategy: RolloutStrategy | None = None,\n):\n    \"\"\"Implements async off-policy training, capped at K steps off policy.\"\"\"\n    assert cfg.async_config is not None\n\n    # We will have groups_per_batch workers generating rollouts, so cap the\n    # queue size to be groups_per_batch.\n    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | _Shutdown](\n        maxsize=cfg.async_config.groups_per_batch\n    )\n    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | _Shutdown | None]()\n\n    # Initial sampling client to use\n    path_dict = await checkpoint_utils.save_checkpoint_async(\n        training_client=training_client,\n        name=f\"{start_batch:06d}\",\n        log_path=cfg.log_path,\n        loop_state={\"batch\": start_batch},\n        kind=\"both\",\n        ttl_seconds=cfg.ttl_seconds,\n    )\n\n    # Shutdown coordination — cascading sequence:\n    # 1. Dataloader exhausts data → sets dataloader_done_event (prevents requeuing stale\n    #    samples) and enqueues one _Shutdown sentinel per worker into env_group_builders_queue.\n    # 2. Each trajectory worker receives its _Shutdown sentinel → exits and decrements\n    #    worker_alive_counter. The last worker enqueues a _Shutdown into trajectory_groups_queue.\n    # 3. Training loop receives _Shutdown from trajectory_groups_queue → finishes current\n    #    batch, sets evaluation_loop_should_shutdown_event, and exits.\n    # 4. Eval loop sees evaluation_loop_should_shutdown_event → exits.\n    dataloader_done_event = asyncio.Event()\n    evaluation_loop_should_shutdown_event = asyncio.Event()\n    worker_alive_counter = _AsyncCounter(cfg.async_config.groups_per_batch)\n\n    # This will be updated by the training loop\n    sampling_client = training_client.create_sampling_client(path_dict[\"sampler_path\"])\n    sampling_client_step = start_batch\n    sampling_client_updated_event = asyncio.Event()\n    sampling_client_updated_event.set()\n\n    @trace.scope\n    async def dataloader_loop():\n        \"\"\"Gets the next set of env builders to run\"\"\"\n        i_batch = start_batch\n        while i_batch < end_batch:\n            env_group_builders_P = dataset.get_batch(i_batch)\n            for env_group_builder in env_group_builders_P:\n                await env_group_builders_queue.put(env_group_builder)\n            i_batch += 1\n\n        # Signal that no more data will be produced, so stale samples should not be requeued\n        dataloader_done_event.set()\n        # Enqueue shutdown sentinels — one per worker — to cascade the shutdown\n        logger.info(\"[dataloader_loop] No more data, shutting down trajectory group workers\")\n        assert cfg.async_config is not None\n        for _ in range(cfg.async_config.groups_per_batch):\n            await env_group_builders_queue.put(_Shutdown())\n        logger.info(\"[dataloader_loop] Terminated\")\n\n    @trace.scope\n    async def trajectory_group_worker_loop():\n        \"\"\"Generates trajectories for a single env builder\"\"\"\n        while True:\n            env_group_builder = await env_group_builders_queue.get()\n            if isinstance(env_group_builder, _Shutdown):\n                logger.info(\"[trajectory_group_worker_loop] Received shutdown signal\")\n                break\n\n            metrics = {}\n            t_start = time.time()\n            # Save a reference to the sampling client step in case it changes\n            # while we're running the rollout\n            
sampling_client_step_copy = sampling_client_step\n            trajectory_group = await do_group_rollout_and_filter_constant_reward(\n                sampling_client,\n                env_group_builder,\n                max_tokens=cfg.max_tokens,\n                temperature=cfg.temperature,\n                do_remove_constant_reward_groups=cfg.remove_constant_reward_groups,\n                strategy=strategy,\n            )\n            # Ingest error info (safe: same event loop thread)\n            if error_counter is not None:\n                error_counter.ingest(trajectory_group)\n            if trajectory_group is None:\n                trajectory_groups_queue.put_nowait(None)\n            else:\n                metrics[\"time/trajectory_group_worker_loop/total\"] = time.time() - t_start\n                trajectory_groups_queue.put_nowait(\n                    WrappedTrajectoryGroup(\n                        trajectory_group=trajectory_group,\n                        env_group_builder=env_group_builder,\n                        sampling_client_step=sampling_client_step_copy,\n                        metrics=metrics,\n                    )\n                )\n\n        # When this is the last worker to exit, signal the training loop to shut down\n        num_alive = await worker_alive_counter.decrement_and_get()\n        if num_alive == 0:\n            logger.info(\n                \"[trajectory_group_worker_loop] Last worker exited, shutting down training loop\"\n            )\n            trajectory_groups_queue.put_nowait(_Shutdown())\n        logger.info(\"[trajectory_group_worker_loop] Terminated\")\n\n    @trace.scope\n    async def training_loop():\n        \"\"\"\n        Waits for a sufficient number of valid trajectories to be accumulated and trains on them.\n        Will discard trajectories that are too stale.\n        \"\"\"\n        assert cfg.async_config is not None\n\n        i_batch = start_batch\n        wrapped_trajectory_groups = []\n        while i_batch < end_batch:\n\n            @trace.scope\n            def filter_stale_trajectory_group(\n                wrapped_trajectory_group: WrappedTrajectoryGroup | None,\n            ) -> bool:\n                \"\"\"Returns False if the trajectory group is too stale or not valid\"\"\"\n                if wrapped_trajectory_group is None:\n                    return False\n\n                # If the samples are too stale, requeue the data so that it will be used eventually.\n                # Skip requeuing during shutdown to avoid deadlocking on a full bounded queue.\n                assert cfg.async_config is not None\n                if (\n                    i_batch - wrapped_trajectory_group.sampling_client_step\n                    > cfg.async_config.max_steps_off_policy\n                ):\n                    if dataloader_done_event.is_set():\n                        logger.info(\n                            f\"[training_loop] Step {i_batch}: Samples are too stale, \"\n                            \"discarding (dataloader done)\"\n                        )\n                    else:\n                        logger.info(\n                            f\"[training_loop] Step {i_batch}: Samples are too stale, requeuing\"\n                        )\n                        asyncio.create_task(\n                            env_group_builders_queue.put(\n                                wrapped_trajectory_group.env_group_builder\n                            ),\n                            
name=\"requeue_stale_sample_task\",\n                        )\n                    return False\n                return True\n\n            metrics = {\n                \"training_client/step\": i_batch,\n                \"optim/lr\": cfg.learning_rate,\n                \"progress/done_frac\": (i_batch + 1) / num_batches,\n            }\n            t_start = time.time()\n\n            nonlocal sampling_client\n            nonlocal sampling_client_step\n            if cfg.stream_minibatch_config is not None:\n                # Streaming minibatch: delegate queue consumption to the streaming function.\n                # We need to check for shutdown before entering the streaming function,\n                # since it will block on queue.get() internally.\n                wrapped_trajectory_group = await trajectory_groups_queue.get()\n                if isinstance(wrapped_trajectory_group, _Shutdown):\n                    logger.info(\"[training_loop] Received shutdown signal\")\n                    break\n                if wrapped_trajectory_group is None:\n                    continue\n                await trajectory_groups_queue.put(wrapped_trajectory_group)\n                streaming_result = await do_train_step_streaming_and_get_sampling_client(\n                    cfg,\n                    i_batch,\n                    trajectory_groups_queue,\n                    training_client,\n                    kl_reference_client,\n                    tokenizer,\n                    filter_stale_trajectory_group,\n                )\n                if streaming_result is None:\n                    logger.info(\"[training_loop] Received shutdown signal from streaming\")\n                    break\n                (\n                    sampling_client,\n                    train_step_metrics,\n                    full_batch_wrapped_trajectory_groups,\n                ) = streaming_result\n                _maybe_export_rollout_summary_jsonl(\n                    cfg=cfg,\n                    file_prefix=f\"train_iteration_{i_batch:06d}\",\n                    split=\"train\",\n                    iteration=i_batch,\n                    groups_P=[\n                        RolloutSummaryGroup(\n                            trajectory_group=group.trajectory_group,\n                            tags=group.env_group_builder.logging_tags(),\n                            sampling_client_step=group.sampling_client_step,\n                        )\n                        for group in full_batch_wrapped_trajectory_groups\n                    ],\n                )\n            else:\n                wrapped_trajectory_group = await trajectory_groups_queue.get()\n                if isinstance(wrapped_trajectory_group, _Shutdown):\n                    logger.info(\"[training_loop] Received shutdown signal\")\n                    break\n                if wrapped_trajectory_group is None:\n                    continue\n\n                if not filter_stale_trajectory_group(wrapped_trajectory_group):\n                    continue\n\n                # Dynamic sampling: Wait for enough trajectories to accumulate to\n                # ensure all batch sizes are the same size. 
This avoids needing to adjust\n                # the learning rate for different batch sizes.\n                wrapped_trajectory_groups.append(wrapped_trajectory_group)\n                if len(wrapped_trajectory_groups) < cfg.async_config.groups_per_batch:\n                    continue\n                logger.info(\n                    f\"[training_loop] Step {i_batch}: Will train on batch, num groups: {len(wrapped_trajectory_groups)}\"\n                )\n\n                # Compute sampling client metrics, as samples may have been generated with\n                # different sampler versions\n                metrics.update(compute_sampling_client_metrics(wrapped_trajectory_groups))\n\n                # TODO: For proper checkpointing, we also need to save dataloader state and\n                # all queued trajectory groups that haven't been trained on yet\n                sampling_client, train_step_metrics = await do_train_step_and_get_sampling_client(\n                    cfg,\n                    i_batch,\n                    training_client,\n                    kl_reference_client,\n                    tokenizer,\n                    [g.env_group_builder for g in wrapped_trajectory_groups],\n                    [g.trajectory_group for g in wrapped_trajectory_groups],\n                )\n                _maybe_export_rollout_summary_jsonl(\n                    cfg=cfg,\n                    file_prefix=f\"train_iteration_{i_batch:06d}\",\n                    split=\"train\",\n                    iteration=i_batch,\n                    groups_P=[\n                        RolloutSummaryGroup(\n                            trajectory_group=group.trajectory_group,\n                            tags=group.env_group_builder.logging_tags(),\n                            sampling_client_step=group.sampling_client_step,\n                        )\n                        for group in wrapped_trajectory_groups\n                    ],\n                )\n            sampling_client_step = i_batch + 1\n            sampling_client_updated_event.set()\n\n            # Log metrics\n            metrics.update(train_step_metrics)\n            if error_counter is not None:\n                metrics.update(error_counter.get_metrics())\n            metrics[\"time/training_loop/total\"] = time.time() - t_start\n            ml_logger.log_metrics(metrics, step=i_batch)\n            i_batch += 1\n            wrapped_trajectory_groups = []\n\n        # Signal evaluation loop to shut down\n        evaluation_loop_should_shutdown_event.set()\n        sampling_client_updated_event.set()\n        logger.info(\"[training_loop] Terminated\")\n\n    @trace.scope\n    async def evaluation_loop():\n        \"\"\"Runs evals periodically\"\"\"\n        if len(evaluators) == 0 or cfg.eval_every == 0:\n            return\n\n        while not evaluation_loop_should_shutdown_event.is_set():\n            await sampling_client_updated_event.wait()\n            sampling_client_updated_event.clear()\n\n            metrics = {}\n            t_start = time.time()\n            # Save a reference to the original values in case it changes\n            # while we're running the evals\n            sampling_client_eval_step = sampling_client_step\n            sampling_client_eval = sampling_client\n            if cfg.eval_every > 0 and sampling_client_eval_step % cfg.eval_every == 0:\n                with timed(\"run_evals\", metrics):\n                    for evaluator in evaluators:\n                        eval_metrics = await 
evaluator(sampling_client_eval)\n                        metrics.update({f\"test/{k}\": v for k, v in eval_metrics.items()})\n                metrics[\"time/evaluation_loop/total\"] = time.time() - t_start\n                ml_logger.log_metrics(metrics, step=sampling_client_eval_step)\n        logger.info(\"[evaluation_loop] Terminated\")\n\n    await asyncio.gather(\n        asyncio.create_task(dataloader_loop(), name=\"dataloader_loop\"),\n        *[\n            asyncio.create_task(\n                trajectory_group_worker_loop(), name=f\"trajectory_group_worker_loop_{i}\"\n            )\n            for i in range(cfg.async_config.groups_per_batch)\n        ],\n        asyncio.create_task(training_loop(), name=\"training_loop\"),\n        asyncio.create_task(evaluation_loop(), name=\"evaluation_loop\"),\n    )\n\n\n@trace.scope\nasync def save_checkpoint_and_get_sampling_client(\n    training_client: tinker.TrainingClient,\n    i_batch: int,\n    log_path: str,\n    save_every: int,\n    start_batch: int = 0,\n    ttl_seconds: int | None = None,\n) -> tuple[tinker.SamplingClient, dict[str, Any]]:\n    metrics = {}\n    with timed(\"save_checkpoint\", metrics):\n        if save_every > 0 and i_batch > start_batch and i_batch % save_every == 0:\n            path_dict = await checkpoint_utils.save_checkpoint_async(\n                training_client=training_client,\n                name=f\"{i_batch:06d}\",\n                log_path=log_path,\n                loop_state={\"batch\": i_batch},\n                kind=\"both\",\n                ttl_seconds=ttl_seconds,\n            )\n            return training_client.create_sampling_client(path_dict[\"sampler_path\"]), metrics\n        else:\n            return await training_client.save_weights_and_get_sampling_client_async(), metrics\n\n\n@trace.scope\nasync def prepare_minibatch(\n    env_group_builders_P: Sequence[EnvGroupBuilder],\n    trajectory_groups_P: list[TrajectoryGroup],\n    tokenizer: Tokenizer,\n    kl_reference_client: tinker.SamplingClient | None,\n    kl_penalty_coef: float,\n    kl_discount_factor: float,\n) -> tuple[list[tinker.Datum], dict[str, Any]]:\n    \"\"\"Converts the trajectories into a minibatch, and provides metrics about the minibatch\"\"\"\n\n    # Compute trajectory metrics\n    metrics = {}\n    taglist_P = [env_group_builder.logging_tags() for env_group_builder in env_group_builders_P]\n    metrics.update(compute_trajectory_metrics(trajectory_groups_P, taglist_P))\n\n    # Print up to two trajectory groups\n    for traj_group in trajectory_groups_P[:2]:\n        print_group(traj_group, tokenizer)\n\n    # Assemble training data\n    with timed(\"assemble_training_data\", metrics):\n        advantages_P = compute_advantages(trajectory_groups_P)\n        data_D, _metadata_D = assemble_training_data(trajectory_groups_P, advantages_P)\n\n    # Incorporate KL penalty if configured\n    if kl_penalty_coef > 0 and kl_reference_client is not None:\n        with timed(\"kl_vs_base\", metrics):\n            kl_penalty_metrics = await incorporate_kl_penalty(\n                data_D,\n                kl_reference_client,\n                kl_penalty_coef,\n                kl_discount_factor,\n            )\n        metrics.update(kl_penalty_metrics)\n\n    return data_D, metrics\n\n\n@trace.scope\nasync def compute_full_batch_metrics_and_get_sampling_client(\n    training_client: tinker.TrainingClient,\n    i_batch: int,\n    data_D: list[tinker.Datum],\n    training_logprobs_D: list[torch.Tensor],\n    log_path: str,\n 
   save_every: int,\n    do_compute_post_kl: bool,\n    ttl_seconds: int | None = None,\n) -> tuple[tinker.SamplingClient, dict[str, Any]]:\n    \"\"\"\n    At the end of the iteration, this will compute metrics for the full batch\n    and return the latest sampling client.\n\n    The reason we return a sampling client is that if do_compute_post_kl is True,\n    we need to create a sampling client from the post-update policy.\n    \"\"\"\n    metrics = {}\n\n    # Compute KL metrics\n    with timed(\"compute_kl_sample_train\", metrics):\n        kl_sample_train_metrics = compute_kl_sample_train(data_D, training_logprobs_D)\n        metrics.update(kl_sample_train_metrics)\n\n    # Get a sampling client using the new weights\n    sampling_client, checkpoint_metrics = await save_checkpoint_and_get_sampling_client(\n        training_client, i_batch, log_path, save_every, ttl_seconds=ttl_seconds\n    )\n    metrics.update(checkpoint_metrics)\n\n    # Compute post-KL metrics if configured\n    if do_compute_post_kl:\n        with timed(\"compute_post_kl\", metrics):\n            post_kl_metrics = await compute_post_kl(data_D, sampling_client)\n            metrics.update(post_kl_metrics)\n\n    return sampling_client, metrics\n\n\n@trace.scope\nasync def do_train_step_streaming_and_get_sampling_client(\n    cfg: Config,\n    i_batch: int,\n    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | _Shutdown | None],\n    training_client: tinker.TrainingClient,\n    kl_reference_client: tinker.SamplingClient | None,\n    tokenizer: Tokenizer,\n    trajectory_group_filter: Callable[[WrappedTrajectoryGroup | None], bool] = lambda _: True,\n) -> tuple[tinker.SamplingClient, dict[str, Any], list[WrappedTrajectoryGroup]] | None:\n    \"\"\"\n    As soon as we have enough trajectories for a minibatch, we will train on them.\n    This allows us to overlap sampling and training.\n\n    Returns None if a shutdown sentinel is received, indicating the caller should\n    stop training.\n    \"\"\"\n    assert cfg.stream_minibatch_config is not None\n    assert cfg.stream_minibatch_config.groups_per_batch % cfg.num_substeps == 0, (\n        f\"{cfg.stream_minibatch_config.groups_per_batch=} must be divisible by {cfg.num_substeps=}\"\n    )\n    # Number of groups across all minibatches in each optimizer substep\n    groups_per_substep = cfg.stream_minibatch_config.groups_per_batch // cfg.num_substeps\n    assert groups_per_substep % cfg.stream_minibatch_config.num_minibatches == 0, (\n        f\"{groups_per_substep} must be divisible by {cfg.stream_minibatch_config.num_minibatches=}\"\n    )\n    # Number of groups per minibatch in each optimizer substep\n    groups_per_minibatch = groups_per_substep // cfg.stream_minibatch_config.num_minibatches\n\n    trace.update_scope_context({\"step\": i_batch})\n\n    metrics = {}\n\n    # Run multiple optimizer substeps per training iteration\n    all_data_D = []\n    all_training_logprobs_D = []\n    all_wrapped_trajectory_groups = []\n    for i_substep in range(cfg.num_substeps):\n        # Run multiple minibatches per substep\n        # Once we have enough trajectories for a minibatch, train on them\n        wrapped_trajectory_groups = []\n        forward_backward_futures: list[tinker.APIFuture[tinker.ForwardBackwardOutput]] = []\n        i_minibatch = 0\n        while i_minibatch < cfg.stream_minibatch_config.num_minibatches:\n            wrapped_trajectory_group = await trajectory_groups_queue.get()\n            if isinstance(wrapped_trajectory_group, 
_Shutdown):\n                logger.info(\"[do_train_step_streaming] Received shutdown signal\")\n                return None\n            if not trajectory_group_filter(wrapped_trajectory_group):\n                continue\n            wrapped_trajectory_groups.append(wrapped_trajectory_group)\n\n            if len(wrapped_trajectory_groups) < groups_per_minibatch:\n                continue\n            logger.info(\n                f\"[stream_minibatch] Step {i_batch}, Substep {i_substep}/{cfg.num_substeps}, Minibatch {i_minibatch}/{cfg.stream_minibatch_config.num_minibatches}: Will train on minibatch, num groups: {len(wrapped_trajectory_groups)}\"\n            )\n\n            # Note: we may have removed trajectory groups that have the same reward.\n            # To have the same results as the sync implementation, we will\n            # remove these and train on a smaller batch.\n            wrapped_trajectory_groups = [g for g in wrapped_trajectory_groups if g is not None]\n            if len(wrapped_trajectory_groups) == 0:\n                i_minibatch += 1\n                continue\n\n            data_D, prepare_minibatch_metrics = await prepare_minibatch(\n                [g.env_group_builder for g in wrapped_trajectory_groups],\n                [g.trajectory_group for g in wrapped_trajectory_groups],\n                tokenizer,\n                kl_reference_client,\n                kl_penalty_coef=cfg.kl_penalty_coef,\n                kl_discount_factor=cfg.kl_discount_factor,\n            )\n            metrics.update(prepare_minibatch_metrics)\n\n            # Enqueue forward-backward (we'll await results after all minibatches are enqueued)\n            with timed(f\"train/fwd_bwd_substep_{i_substep}_mb_{i_minibatch}_enqueue\", metrics):\n                forward_backward_futures.append(\n                    await training_client.forward_backward_async(\n                        [_remove_mask(d) for d in data_D],\n                        loss_fn=cfg.loss_fn,\n                        loss_fn_config=cfg.loss_fn_config,\n                    )\n                )\n            all_data_D.extend(data_D)\n            all_wrapped_trajectory_groups.extend(wrapped_trajectory_groups)\n            i_minibatch += 1\n            wrapped_trajectory_groups = []\n\n        # Enqueue optim_step before awaiting results (so they land on same clock cycle)\n        adam_params = tinker.AdamParams(\n            learning_rate=cfg.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8\n        )\n        with timed(f\"train/optim_substep_{i_substep}_enqueue\", metrics):\n            optim_future = await training_client.optim_step_async(adam_params)\n\n        # Now consume all forward-backward results\n        for i_mb, fwd_bwd_future in enumerate(forward_backward_futures):\n            with timed(f\"train/fwd_bwd_substep_{i_substep}_mb_{i_mb}_consume\", metrics):\n                fwd_bwd_result = await fwd_bwd_future.result_async()\n                all_training_logprobs_D.extend(_training_logprobs_from_fwd_bwd(fwd_bwd_result))\n\n        with timed(f\"train/optim_substep_{i_substep}_consume\", metrics):\n            optim_result = await optim_future.result_async()\n\n        if optim_result.metrics:\n            metrics.update(optim_result.metrics)\n\n    # Aggregate metrics across the entire batch\n    metrics.update(compute_sampling_client_metrics(all_wrapped_trajectory_groups))\n    metrics.update(\n        compute_trajectory_metrics(\n            [g.trajectory_group for g in all_wrapped_trajectory_groups],\n   
         [g.env_group_builder.logging_tags() for g in all_wrapped_trajectory_groups],\n        )\n    )\n    (\n        sampling_client,\n        full_batch_metrics,\n    ) = await compute_full_batch_metrics_and_get_sampling_client(\n        training_client,\n        # NOTE: saving the checkpoint as the i + 1 step\n        i_batch + 1,\n        all_data_D,\n        all_training_logprobs_D,\n        cfg.log_path,\n        cfg.save_every,\n        cfg.compute_post_kl,\n        cfg.ttl_seconds,\n    )\n    metrics.update(full_batch_metrics)\n    return sampling_client, metrics, all_wrapped_trajectory_groups\n\n\n@trace.scope\nasync def do_train_step_and_get_sampling_client(\n    cfg: Config,\n    i_batch: int,\n    training_client: tinker.TrainingClient,\n    kl_reference_client: tinker.SamplingClient | None,\n    tokenizer: Tokenizer,\n    env_group_builders_P: Sequence[EnvGroupBuilder],\n    trajectory_groups_P: list[TrajectoryGroup],\n) -> tuple[tinker.SamplingClient, dict[str, Any]]:\n    trace.update_scope_context({\"step\": i_batch})\n\n    metrics = {}\n    data_D, prepare_minibatch_metrics = await prepare_minibatch(\n        env_group_builders_P,\n        trajectory_groups_P,\n        tokenizer,\n        kl_reference_client,\n        kl_penalty_coef=cfg.kl_penalty_coef,\n        kl_discount_factor=cfg.kl_discount_factor,\n    )\n    metrics.update(prepare_minibatch_metrics)\n\n    with timed(\"train\", metrics):\n        training_logprobs_D = await train_step(\n            data_D=data_D,\n            training_client=training_client,\n            learning_rate=cfg.learning_rate,\n            num_substeps=cfg.num_substeps,\n            loss_fn=cfg.loss_fn,\n            loss_fn_config=cfg.loss_fn_config,\n            metrics=metrics,\n        )\n\n    sampling_client, full_batch_metrics = await compute_full_batch_metrics_and_get_sampling_client(\n        training_client,\n        # NOTE: saving the checkpoint as the i + 1 step\n        i_batch + 1,\n        data_D,\n        training_logprobs_D,\n        cfg.log_path,\n        cfg.save_every,\n        cfg.compute_post_kl,\n        cfg.ttl_seconds,\n    )\n    metrics.update(full_batch_metrics)\n\n    return sampling_client, metrics\n\n\n@trace.scope\nasync def do_sync_training(\n    start_batch: int,\n    end_batch: int,\n    num_batches: int,\n    cfg: Config,\n    training_client: tinker.TrainingClient,\n    kl_reference_client: tinker.SamplingClient | None,\n    evaluators: list[SamplingClientEvaluator],\n    dataset: RLDataset,\n    ml_logger: ml_log.Logger,\n    tokenizer: Tokenizer,\n    error_counter: RolloutErrorCounter | None = None,\n    strategy: RolloutStrategy | None = None,\n):\n    \"\"\"Implements fully synchronous on-policy training\"\"\"\n    # Initial sampling client\n    sampling_client, _ = await save_checkpoint_and_get_sampling_client(\n        training_client, start_batch, cfg.log_path, cfg.save_every, start_batch, cfg.ttl_seconds\n    )\n\n    for i_batch in range(start_batch, end_batch):\n        metrics: dict[str, Any] = {\n            \"progress/batch\": i_batch,\n            \"optim/lr\": cfg.learning_rate,\n            \"progress/done_frac\": (i_batch + 1) / num_batches,\n        }\n\n        with trace.trace_iteration(step=i_batch) as window:\n            # Run evaluations\n            if cfg.eval_every > 0 and i_batch % cfg.eval_every == 0:\n                eval_metrics = await run_evaluations_parallel(\n                    evaluators, sampling_client, cfg, i_batch\n                )\n                
metrics.update(eval_metrics)\n\n            # Get batch and sample trajectories\n            env_group_builders_P = dataset.get_batch(i_batch)\n\n            # Initialize logtree trace for this iteration if logging is enabled\n            with _get_logtree_scope(\n                log_path=cfg.log_path,\n                num_groups_to_log=cfg.num_groups_to_log,\n                f_name=f\"train_iteration_{i_batch:06d}\",\n                scope_name=f\"RL Iteration {i_batch}\",\n            ):\n                # Note: do_remove_constant_reward_groups=False here because we remove\n                # constant reward groups after all rollouts are collected (below)\n                results_P = await gather_with_progress(\n                    (\n                        do_group_rollout_and_filter_constant_reward(\n                            sampling_client,\n                            builder,\n                            max_tokens=cfg.max_tokens,\n                            temperature=cfg.temperature,\n                            do_remove_constant_reward_groups=False,\n                            enable_logging=i < cfg.num_groups_to_log,\n                            strategy=strategy,\n                        )\n                        for i, builder in enumerate(env_group_builders_P)\n                    ),\n                    desc=f\"Sampling batch {i_batch}\",\n                )\n\n            # Ingest error info from results\n            if error_counter is not None:\n                for result in results_P:\n                    error_counter.ingest(result)\n\n            # Filter out None results (from errored or fully-failed groups)\n            successful = [\n                (builder, tg)\n                for builder, tg in safezip(env_group_builders_P, results_P)\n                if tg is not None\n            ]\n            batch_skipped = not successful\n            if batch_skipped:\n                logger.warning(f\"Batch {i_batch}: all groups failed or filtered, skipping batch\")\n            else:\n                env_group_builders_P = [s[0] for s in successful]\n                trajectory_groups_P: list[TrajectoryGroup] = [s[1] for s in successful]\n\n                _maybe_export_rollout_summary_jsonl(\n                    cfg=cfg,\n                    file_prefix=f\"train_iteration_{i_batch:06d}\",\n                    split=\"train\",\n                    iteration=i_batch,\n                    groups_P=[\n                        RolloutSummaryGroup(\n                            trajectory_group=trajectory_group,\n                            tags=env_group_builder.logging_tags(),\n                            sampling_client_step=i_batch,\n                        )\n                        for env_group_builder, trajectory_group in safezip(\n                            env_group_builders_P, trajectory_groups_P\n                        )\n                    ],\n                )\n\n                if cfg.remove_constant_reward_groups:\n                    trajectory_groups_P = remove_constant_reward_groups(trajectory_groups_P)\n\n                # Train step\n                sampling_client, train_step_metrics = await do_train_step_and_get_sampling_client(\n                    cfg,\n                    i_batch,\n                    training_client,\n                    kl_reference_client,\n                    tokenizer,\n                    env_group_builders_P,\n                    trajectory_groups_P,\n                )\n\n                
metrics.update(train_step_metrics)\n\n        metrics.update(window.get_timing_metrics())\n        if error_counter is not None:\n            metrics.update(error_counter.get_metrics())\n        window.write_spans_jsonl(Path(cfg.log_path) / \"timing_spans.jsonl\", step=i_batch)\n        if cfg.span_chart_every > 0 and i_batch % cfg.span_chart_every == 0:\n            trace.save_gantt_chart_html(\n                window, i_batch, Path(cfg.log_path) / f\"timing_gantt_{i_batch:06d}.html\"\n            )\n        ml_logger.log_metrics(metrics, step=i_batch)\n\n\n@trace.scope\nasync def main(\n    cfg: Config,\n    rollout_executor: Executor | None = None,\n):\n    \"\"\"Main training loop for MDP RL.\n\n    Args:\n        cfg: Training configuration.\n        rollout_executor: Optional ``concurrent.futures.Executor`` for offloading\n            group rollouts to separate processes or remote workers. Pass\n            ``ProcessPoolExecutor(max_workers=N)`` for multi-process execution,\n            or any custom ``Executor`` (Ray, cluster dispatchers, etc.).\n            Default ``None`` runs rollouts as asyncio coroutines in-process.\n    \"\"\"\n    if rollout_executor is not None:\n        set_rollout_executor(rollout_executor)\n    ml_logger = ml_log.setup_logging(\n        log_dir=cfg.log_path,\n        wandb_project=cfg.wandb_project,\n        config=cfg,\n        wandb_name=cfg.wandb_name,\n    )\n    if cfg.enable_trace:\n        # Get and rename the current (main) task\n        current_task = asyncio.current_task()\n        if current_task is not None:\n            current_task.set_name(\"main\")\n        trace_events_path = str(Path(cfg.log_path) / \"trace_events.jsonl\")\n        logger.info(f\"Tracing is enabled. Trace events will be saved to {trace_events_path}\")\n        logger.info(\n            f\"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/\"\n        )\n        trace.trace_init(output_file=trace_events_path)\n\n    logging.getLogger(\"httpx\").setLevel(logging.WARNING)\n    logging.getLogger(\"pylatexenc\").setLevel(logging.WARNING)\n\n    resume_info = checkpoint_utils.get_last_checkpoint(cfg.log_path)\n    if resume_info:\n        start_batch = resume_info.batch\n    else:\n        start_batch = 0\n\n    service_client = tinker.ServiceClient(base_url=cfg.base_url)\n    user_metadata: dict[str, str] = {}\n    if wandb_link := ml_logger.get_logger_url():\n        user_metadata[\"wandb_link\"] = wandb_link\n    checkpoint_utils.add_renderer_name_to_user_metadata(user_metadata, cfg.renderer_name)\n    model_info.warn_if_renderer_not_recommended(cfg.model_name, cfg.renderer_name)\n\n    if resume_info:\n        # Resuming interrupted training - load optimizer state for proper continuation\n        await checkpoint_utils.check_renderer_name_for_checkpoint_async(\n            service_client, resume_info.state_path, cfg.renderer_name\n        )\n        training_client = (\n            await service_client.create_training_client_from_state_with_optimizer_async(\n                resume_info.state_path, user_metadata=user_metadata\n            )\n        )\n        logger.info(f\"Resumed training from {resume_info.state_path}\")\n    elif cfg.load_checkpoint_path:\n        # Starting fresh from a checkpoint - load weights only (fresh optimizer)\n        await checkpoint_utils.check_renderer_name_for_checkpoint_async(\n            service_client, cfg.load_checkpoint_path, cfg.renderer_name\n     
   )\n        training_client = await service_client.create_training_client_from_state_async(\n            cfg.load_checkpoint_path, user_metadata=user_metadata\n        )\n        logger.info(f\"Loaded weights from {cfg.load_checkpoint_path}\")\n    else:\n        training_client = await service_client.create_lora_training_client_async(\n            cfg.model_name, rank=cfg.lora_rank, user_metadata=user_metadata\n        )\n\n    # Get tokenizer from training client\n    tokenizer = training_client.get_tokenizer()\n\n    # Create dataset from thunk\n    dataset, maybe_test_dataset = await cfg.dataset_builder()\n    # Build rollout strategy and error counter from config\n    strategy = rollout_strategy_from_config(cfg.rollout_error_tolerance)\n    error_counter = RolloutErrorCounter() if strategy.catches_group_errors else None\n\n    evaluators = [evaluator() for evaluator in cfg.evaluator_builders]\n    if maybe_test_dataset is not None:\n        evaluators.append(\n            RLTestSetEvaluator(\n                maybe_test_dataset,\n                max_tokens=cfg.max_tokens,\n                strategy=strategy,\n            )\n        )\n\n    num_batches = len(dataset)\n    end_batch = min(cfg.max_steps, num_batches) if cfg.max_steps is not None else num_batches\n    logger.info(f\"Will train on {end_batch} batches\")\n\n    # Create KL reference client once if KL penalty is enabled\n    if cfg.kl_penalty_coef > 0:\n        if cfg.kl_reference_config is None:\n            raise ConfigurationError(\n                \"kl_reference_config must be specified when kl_penalty_coef > 0\"\n            )\n        kl_reference_client = service_client.create_sampling_client(\n            base_model=cfg.kl_reference_config.base_model,\n            model_path=cfg.kl_reference_config.load_checkpoint_path,\n        )\n    else:\n        kl_reference_client = None\n\n    # Training loop\n    if cfg.async_config is not None:\n        training_func = do_async_training\n    elif cfg.stream_minibatch_config is not None:\n        training_func = do_sync_training_with_stream_minibatch\n    else:\n        training_func = do_sync_training\n    await training_func(\n        start_batch=start_batch,\n        end_batch=end_batch,\n        num_batches=end_batch,\n        cfg=cfg,\n        training_client=training_client,\n        kl_reference_client=kl_reference_client,\n        evaluators=evaluators,\n        dataset=dataset,\n        ml_logger=ml_logger,\n        tokenizer=tokenizer,\n        error_counter=error_counter,\n        strategy=strategy,\n    )\n\n    # Save final checkpoint\n    if start_batch < end_batch:\n        _ = await checkpoint_utils.save_checkpoint_async(\n            training_client=training_client,\n            name=\"final\",\n            log_path=cfg.log_path,\n            kind=\"both\",\n            loop_state={\"batch\": end_batch},\n            ttl_seconds=None,\n        )\n    else:\n        logger.info(\"Training was already complete; nothing to do\")\n\n    # Cleanup\n    if rollout_executor is not None:\n        rollout_executor.shutdown(wait=True)\n        set_rollout_executor(None)\n    ml_logger.close()\n    logger.info(\"Training completed successfully\")\n"
  },
  {
    "path": "tinker_cookbook/rl/types.py",
    "content": "\"\"\"\nBasic interfaces and types for reinforcement learning.\n\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass, field\nfrom typing import TypeAlias\n\nimport chz\nimport tinker\n\nfrom tinker_cookbook.completers import StopCondition, TokensWithLogprobs\nfrom tinker_cookbook.utils.misc_utils import safezip\n\nAction: TypeAlias = list[int]\nObservation: TypeAlias = tinker.ModelInput\nLogprobs: TypeAlias = list[float]\nMetrics: TypeAlias = dict[str, float | int]\nLogs: TypeAlias = dict[str, str | int | float]\n\n\n@dataclass\nclass StepResult:\n    \"\"\"Result returned by Env.step().\"\"\"\n\n    reward: float\n    \"\"\"Immediate reward for this step.\"\"\"\n    episode_done: bool\n    \"\"\"Whether the episode has ended.\"\"\"\n    next_observation: Observation\n    \"\"\"Observation for the next step (or final observation if episode_done).\"\"\"\n    next_stop_condition: StopCondition\n    \"\"\"Stop condition for the next generation.\"\"\"\n    metrics: Metrics = field(default_factory=dict)\n    \"\"\"Numeric values aggregated and reported in training logs (e.g., timing, counts).\"\"\"\n    logs: Logs = field(default_factory=dict)\n    \"\"\"Diagnostic info for display/debugging tools (not aggregated like metrics).\"\"\"\n\n\n@dataclass\nclass Transition:\n    \"\"\"A single (observation, action, reward) tuple from a trajectory.\"\"\"\n\n    ob: Observation\n    \"\"\"Observation the agent saw before taking the action.\"\"\"\n    ac: TokensWithLogprobs\n    \"\"\"Action taken (tokens and their log-probabilities).\"\"\"\n    reward: float\n    \"\"\"Immediate reward received after taking the action.\"\"\"\n    episode_done: bool\n    \"\"\"Whether this transition ended the episode.\"\"\"\n    metrics: Metrics = field(default_factory=dict)\n    \"\"\"Numeric values aggregated and reported in training logs.\"\"\"\n    logs: Logs = field(default_factory=dict)\n    \"\"\"Diagnostic info for display/debugging tools (not aggregated like metrics).\"\"\"\n\n\nclass Env(ABC):\n    \"\"\"\n    Stateful environment that a single agent interacts with.\n    Discard after running for one episode.\n    \"\"\"\n\n    @abstractmethod\n    async def initial_observation(self) -> tuple[Observation, StopCondition]:\n        pass\n\n    @abstractmethod\n    async def step(self, action: Action) -> StepResult:\n        pass\n\n\n@dataclass(frozen=True)\nclass Trajectory:\n    \"\"\"\n    A sequence of observations and actions, resulting from running a single agent in a single\n    environment.\n    \"\"\"\n\n    transitions: list[Transition]\n    final_ob: Observation\n\n\n@dataclass(frozen=True)\nclass RolloutError:\n    \"\"\"A captured error from a failed trajectory rollout.\n\n    Stored on :class:`TrajectoryGroup` so error information flows through\n    return values (including across process boundaries via pickle) without\n    requiring shared mutable state.\n    \"\"\"\n\n    error_type: str\n    \"\"\"The exception class name, e.g. ``'BadRequestError'``.\"\"\"\n    error_message: str\n    \"\"\"``str(exception)`` — the human-readable error description.\"\"\"\n\n\nclass EnvGroupBuilder(ABC):\n    \"\"\"\n    Builds a group of environments. 
The group will be used in the following way:\n\n    - Algorithms like GRPO will center rewards across the group.\n    - The reward function (compute_group_rewards) has access to the trajectories from the\n      whole group, even though many reward functions will evaluate each one independently.\n\n      - For example, this enables us to use pairwise reward models that look at a pair of\n        trajectories at a time. With such a reward model, we effectively have a multi-agent\n        environment, where the agents are playing a zero-sum game.\n\n    Groups can be used in two ways, in practice:\n\n    - To define a multi-agent environment\n    - As a part of the *algorithm* (e.g. GRPO), when dealing with single-agent tasks.\n\n    **Picklability:** Implementations must be pickleable (via standard ``pickle``) to support\n    distributed rollout execution where builders are serialized and sent to remote workers.\n    Avoid storing live network connections, file handles, or other unpickleable objects as\n    fields. Use ``get_renderer()`` to create Renderers (which are automatically pickle-safe).\n    Store configuration strings (model names, connection params) and construct heavy objects\n    in ``make_envs()`` when possible. See ``HarborEnvGroupBuilder`` for a reference\n    implementation of the lazy-construction pattern.\n    \"\"\"\n\n    @abstractmethod\n    async def make_envs(self) -> Sequence[Env]:\n        pass\n\n    async def compute_group_rewards(\n        self, trajectory_group: list[Trajectory], env_group: Sequence[Env]\n    ) -> list[tuple[float, Metrics]]:\n        \"\"\"\n        This computes a final reward for each trajectory that depends on the whole group.\n        Note that there are also per-timestep rewards returned by the Env.step() method.\n        The total reward is the sum of the per-timestep rewards plus the final group reward\n        computed here. Defining a group reward is optional -- by default, the group reward\n        is 0 and we only use the per-timestep rewards.\n        \"\"\"\n        return [(0.0, {}) for _ in trajectory_group]\n\n    async def cleanup(self) -> None:\n        \"\"\"Clean up resources created by make_envs().\n\n        Called after rollouts and reward computation complete, regardless\n        of success or failure. Override this to release expensive resources\n        like cloud sandboxes, remote browsers, etc.\n\n        Default is a no-op. Implementations should be idempotent (safe to\n        call multiple times) and handle exceptions internally, as `do_group_rollout`\n        does not catch exceptions from this method.\n        \"\"\"\n        pass\n\n    def logging_tags(self) -> list[str]:\n        \"\"\"\n        This is just used for logging. We often want to aggregate metrics (like rewards\n        or episode lengths) per-environment, or across a group of related environments.\n\n        Most commonly, you'd return a short name for the environment, such as ['gsm'] for\n        grade school math. 
You also might want a few tags at different levels of granularity,\n        e.g., ['gsm', 'math', 'rlvr']\n        \"\"\"\n        return []\n\n\n@dataclass\nclass TrajectoryGroup:\n    \"\"\"\n    A group of trajectories, resulting from instantiating a group of environments using an\n    EnvGroupBuilder, doing a rollout for each environment, and computing the rewards.\n    \"\"\"\n\n    trajectories_G: list[Trajectory]\n    final_rewards_G: list[float]  # computed by the EnvGroupBuilder, looking at whole group\n    metrics_G: list[Metrics]\n\n    # Error tracking — populated by do_group_rollout when using error-tolerant strategies.\n    # Empty list means no trajectory errors occurred.\n    rollout_errors: list[RolloutError] = field(default_factory=list)\n\n    def get_total_rewards(self) -> list[float]:\n        \"\"\"\n        Get the total reward (i.e., the return) of each trajectory (episode) in the group.\n        The total reward is the sum of the per-timestep rewards plus the final group reward\n        computed by the EnvGroupBuilder.\n        \"\"\"\n        return [\n            sum(transition.reward for transition in trajectory.transitions) + final_reward\n            for trajectory, final_reward in safezip(self.trajectories_G, self.final_rewards_G)\n        ]\n\n\nclass RLDataset(ABC):\n    \"\"\"\n    A dataset that produces batches of EnvGroups. This is the kind of dataset used by\n    training algorithms.\n    \"\"\"\n\n    @abstractmethod\n    def get_batch(self, index: int) -> Sequence[EnvGroupBuilder]:\n        pass\n\n    @abstractmethod\n    def __len__(self) -> int:\n        pass\n\n\n@chz.chz\nclass RLDatasetBuilder:\n    \"\"\"\n    Abstract class for building RL datasets.\n    \"\"\"\n\n    @abstractmethod\n    async def __call__(self) -> tuple[RLDataset, RLDataset | None]:\n        \"\"\"\n        Return RLDataset (for training) and an optional RL dataset for testing\n        \"\"\"\n        pass\n"
  },
  {
    "path": "tinker_cookbook/sandbox/README.md",
    "content": "# Sandboxing\n\nThis directory contains code execution backends for sandboxed evaluation (e.g., grading code in RL environments).\n\nThere are currently two available backends: SandboxFusion for local execution and Modal for cloud execution.\n\n## Backends\n\n### SandboxFusion (local Docker)\n\n[Sandbox Fusion](https://bytedance.github.io/SandboxFusion/) is a Docker-based code execution sandbox. Start a local sandbox in Docker with:\n\n```bash\ndocker run -it -p 8080:8080 volcengine/sandbox-fusion:server-20250609\n```\n\nFor RL workloads, you may want higher concurrency. See [`recipes/code_rl/sandbox_config/local.yaml`](../recipes/code_rl/sandbox_config/local.yaml) for an example configuration that can be mounted with `-v`, and see [`recipes/code_rl/README.md`](../recipes/code_rl/README.md) for instructions on using it.\n\nIf you prefer not to use Docker, see the [Sandbox Fusion repository](https://github.com/bytedance/SandboxFusion?tab=readme-ov-file#installation) for manual setup.\n\nExample usage:\n\n```python\nfrom tinker_cookbook.sandbox import SandboxFusionClient\n\nclient = SandboxFusionClient()\nsuccess, response = await client.run(\n    code=\"print('hello')\",\n    files={\"data.txt\": \"some content\"},\n    timeout=30,\n)\nawait client.close()\n```\n\nEnvironment variables:\n\n- `SANDBOX_URL`: Endpoint URL (default: `http://localhost:8080/run_code`)\n- `SANDBOX_MAX_CONCURRENCY`: Max concurrent requests (default: 4)\n\n### Modal (cloud)\n\n[Modal Sandboxes](https://modal.com/products/sandboxes) provide cloud-based isolated execution environments. Requires authentication with: `modal token new`\n\nExample usage:\n\n```python\nfrom tinker_cookbook.sandbox.modal_sandbox import ModalSandbox, ModalSandboxPool\n\n# Single sandbox (conforms to SandboxInterface)\nsandbox = await ModalSandbox.create()\nawait sandbox.write_file(\"/workspace/code.py\", \"print('hello')\")\nresult = await sandbox.run_command(\"python /workspace/code.py\", workdir=\"/workspace\")\nprint(result.stdout)\nawait sandbox.cleanup()\n\n# Pool for concurrent execution (recommended for RL workloads)\npool = ModalSandboxPool(pool_size=32)\nresult = await pool.run_in_workdir(\n    files={\"code.py\": \"print('hello')\"},\n    command=[\"python\", \"code.py\"],\n)\nprint(result.stdout)\n```\n\nEnvironment variables:\n\n- `MODAL_POOL_SIZE`: Number of sandboxes in the pool (default: 32)\n"
  },
  {
    "path": "tinker_cookbook/sandbox/__init__.py",
    "content": "\"\"\"\nCode execution backends for sandboxed code evaluation.\n\nThe sandbox/ directory provides thin wrappers around different sandbox backends:\n- SandboxFusionClient: HTTP-based sandbox using SandboxFusion Docker container\n- ModalSandbox: Cloud sandbox using Modal's infrastructure\n\"\"\"\n\nfrom enum import StrEnum\n\nfrom tinker_cookbook.sandbox.sandbox_interface import (\n    SandboxInterface,\n    SandboxResult,\n    SandboxTerminatedError,\n)\nfrom tinker_cookbook.sandbox.sandboxfusion import SandboxFusionClient\n\n\nclass SandboxBackend(StrEnum):\n    SANDBOXFUSION = \"sandboxfusion\"\n    MODAL = \"modal\"\n\n\n__all__ = [\n    \"SandboxBackend\",\n    \"SandboxFusionClient\",\n    \"SandboxInterface\",\n    \"SandboxResult\",\n    \"SandboxTerminatedError\",\n]\n"
  },
  {
    "path": "tinker_cookbook/sandbox/modal_sandbox.py",
    "content": "\"\"\"\nThin wrapper around Modal Sandbox API.\n\nModal provides cloud-based sandboxed execution environments.\nRequires Modal authentication: `modal token new`\n\nConfiguration via environment variables:\n    MODAL_POOL_SIZE: Number of sandboxes in the pool (default: 32)\n    MODAL_CREATION_RATE_LIMIT: Max sandboxes created per second (default: 4)\n\nSee: https://modal.com/docs/guide/sandbox\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport contextlib\nimport logging\nimport os\nimport shlex\nimport uuid\n\ntry:\n    import modal\nexcept ImportError:\n    raise ImportError(\n        \"modal is required for ModalSandbox. \"\n        \"Install it with: uv pip install 'tinker-cookbook[modal] @ \"\n        \"git+https://github.com/thinking-machines-lab/tinker-cookbook.git@nightly'\"\n    ) from None\n\nfrom tinker_cookbook.exceptions import SandboxError\nfrom tinker_cookbook.sandbox.sandbox_interface import SandboxResult, SandboxTerminatedError\n\nlogger = logging.getLogger(__name__)\n\n\nasync def _read_stream_capped(stream: object, max_bytes: int) -> str:\n    \"\"\"Read a Modal async stream up to *max_bytes*, draining the rest to avoid blocking.\"\"\"\n    chunks: list[bytes] = []\n    total = 0\n    try:\n        async for chunk in stream:  # type: ignore[union-attr]\n            if isinstance(chunk, str):\n                chunk = chunk.encode()\n            remaining = max_bytes - total\n            if remaining <= 0:\n                break\n            chunks.append(chunk[:remaining])\n            total += len(chunk[:remaining])\n    except UnicodeDecodeError:\n        pass  # Modal internal decoding error — return what we have\n\n    # Drain any remaining data so the process can exit cleanly\n    try:\n        async for _ in stream:  # type: ignore[union-attr]\n            pass\n    except (UnicodeDecodeError, Exception):\n        pass\n\n    return b\"\".join(chunks).decode(\"utf-8\", errors=\"replace\")\n\n\ndef _is_sandbox_terminated(e: BaseException) -> bool:\n    \"\"\"Check if an exception indicates the sandbox has died.\"\"\"\n    type_name = type(e).__name__\n    if type_name == \"NotFoundError\":\n        return True\n    msg = str(e).lower()\n    return any(keyword in msg for keyword in (\"terminated\", \"died\", \"not found\"))\n\n\nclass ModalSandbox:\n    \"\"\"\n    Persistent Modal sandbox for code execution. 
Conforms to SandboxInterface.\n\n    Usage:\n        sandbox = await ModalSandbox.create()\n\n        await sandbox.write_file(\"/workspace/code.py\", \"print('hello')\")\n        result = await sandbox.run_command(\"python /workspace/code.py\")\n        print(result.stdout)\n\n        await sandbox.cleanup()\n    \"\"\"\n\n    def __init__(\n        self,\n        timeout: int,\n        image: modal.Image,\n        app: modal.App,\n        sandbox: modal.Sandbox,\n        max_stream_output_bytes: int = 128 * 1024,\n    ) -> None:\n        self._timeout = timeout  # Timeout for the entire Sandbox instance\n        self._image = image\n        self._app = app\n        self._sandbox = sandbox\n        self._max_stream_output_bytes = max_stream_output_bytes\n\n    @classmethod\n    async def create(\n        cls,\n        app_name: str = \"tinker-cookbook-runner\",\n        timeout: int = 600,\n        image: modal.Image | None = None,\n        max_stream_output_bytes: int = 128 * 1024,\n    ) -> ModalSandbox:\n        \"\"\"Create a new Modal sandbox.\"\"\"\n        image = image or modal.Image.debian_slim()\n        app = await modal.App.lookup.aio(app_name, create_if_missing=True)\n        sandbox = await modal.Sandbox.create.aio(app=app, image=image, timeout=timeout)\n        return cls(\n            timeout=timeout,\n            image=image,\n            app=app,\n            sandbox=sandbox,\n            max_stream_output_bytes=max_stream_output_bytes,\n        )\n\n    @property\n    def sandbox_id(self) -> str:\n        return self._sandbox.object_id\n\n    async def send_heartbeat(self) -> None:\n        await self._sandbox.exec.aio(\"true\")\n\n    async def run_command(\n        self,\n        command: str,\n        workdir: str | None = None,\n        timeout: int = 60,\n        max_output_bytes: int | None = None,\n    ) -> SandboxResult:\n        \"\"\"Run a shell command in the sandbox.\"\"\"\n        cap = max_output_bytes if max_output_bytes is not None else self._max_stream_output_bytes\n        try:\n            proc = await self._sandbox.exec.aio(\n                \"bash\", \"-lc\", command, timeout=timeout, workdir=workdir\n            )\n            stdout, stderr, exit_code = await asyncio.gather(\n                _read_stream_capped(proc.stdout, cap),\n                _read_stream_capped(proc.stderr, cap),\n                proc.wait.aio(),\n            )\n            return SandboxResult(stdout=stdout, stderr=stderr, exit_code=exit_code)\n        except Exception as e:\n            if _is_sandbox_terminated(e):\n                raise SandboxTerminatedError(str(e)) from e\n            return SandboxResult(stdout=\"\", stderr=str(e), exit_code=-1)\n\n    async def read_file(\n        self, path: str, max_bytes: int | None = None, timeout: int = 60\n    ) -> SandboxResult:\n        \"\"\"Read a file from the sandbox.\"\"\"\n        if max_bytes is not None:\n            cmd = f\"head -c {max_bytes} {shlex.quote(path)}\"\n        else:\n            cmd = f\"cat {shlex.quote(path)}\"\n        return await self.run_command(cmd, timeout=timeout)\n\n    async def write_file(\n        self,\n        path: str,\n        content: str | bytes = \"\",\n        executable: bool = False,\n        timeout: int = 60,\n    ) -> SandboxResult:\n        \"\"\"Write content to a file in the sandbox.\"\"\"\n        if isinstance(content, str):\n            content = content.encode()\n\n        dir_path = os.path.dirname(path)\n        quoted_path = shlex.quote(path)\n\n        cmd = 
f\"mkdir -p {shlex.quote(dir_path)} && cat > {quoted_path}\"\n        if executable:\n            cmd += f\" && chmod +x {quoted_path}\"\n\n        try:\n            proc = await self._sandbox.exec.aio(\"bash\", \"-lc\", cmd, timeout=timeout)\n\n            # Write content in 2 MiB chunks via stdin\n            chunk_size = 2 * 1024 * 1024\n            for i in range(0, len(content), chunk_size):\n                chunk = content[i : i + chunk_size]\n                try:\n                    proc.stdin.write(chunk)\n                except TypeError:\n                    proc.stdin.write(chunk.decode(\"utf-8\", errors=\"replace\"))\n                await proc.stdin.drain.aio()\n            proc.stdin.write_eof()\n            await proc.stdin.drain.aio()\n\n            stdout, stderr, exit_code = await asyncio.gather(\n                _read_stream_capped(proc.stdout, self._max_stream_output_bytes),\n                _read_stream_capped(proc.stderr, self._max_stream_output_bytes),\n                proc.wait.aio(),\n            )\n            return SandboxResult(stdout=stdout, stderr=stderr, exit_code=exit_code)\n        except Exception as e:\n            if _is_sandbox_terminated(e):\n                raise SandboxTerminatedError(str(e)) from e\n            return SandboxResult(stdout=\"\", stderr=str(e), exit_code=-1)\n\n    async def cleanup(self) -> None:\n        \"\"\"Terminate the Modal sandbox and wait for it to fully shut down.\"\"\"\n        await self._sandbox.terminate.aio()\n        with contextlib.suppress(modal.exception.SandboxTimeoutError):\n            await self._sandbox.wait.aio(raise_on_termination=False)\n\n\nclass ModalSandboxPool:\n    \"\"\"\n    Pool of Modal sandboxes for concurrent execution.\n\n    Each sandbox handles one request at a time. 
The pool manages\n    borrowing and returning sandboxes automatically.\n\n    Configuration via environment variables:\n        MODAL_POOL_SIZE: Number of sandboxes in the pool (default: 32)\n        MODAL_CREATION_RATE_LIMIT: Max sandboxes created per second (default: 4)\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        pool_size: int | None = None,  # Number of warm sandboxes to maintain during the job run.\n        sandbox_timeout_secs: int = 1200,  # Time after which a sandbox is terminated.\n        image: modal.Image | None = None,\n        app_name: str = \"tinker-cookbook-runner\",\n    ):\n        self._pool_size = pool_size or int(os.getenv(\"MODAL_POOL_SIZE\", \"32\"))\n        self._creation_rate_limit = int(os.getenv(\"MODAL_CREATION_RATE_LIMIT\", \"4\"))\n        self._sandbox_timeout_secs = sandbox_timeout_secs\n        self._image = image\n        self._app_name = app_name\n        self._terminated = False\n\n        self._warm_pool: asyncio.Queue[ModalSandbox] = asyncio.Queue()  # Warm pool of sandboxes.\n        self._to_terminate: list[ModalSandbox] = []  # Sandboxes pending termination.\n        self._active_count = 0  # Number of in-use sandboxes.\n\n        asyncio.create_task(self._maintain_pool())\n\n    async def _create(self) -> ModalSandbox:\n        return await ModalSandbox.create(\n            app_name=self._app_name, timeout=self._sandbox_timeout_secs, image=self._image\n        )\n\n    async def _maintain_pool(self) -> None:\n        \"\"\"Background task to handle all sandbox creation and termination.\"\"\"\n        while not self._terminated:\n            try:\n                await self._maintain_pool_step()\n            except Exception as e:\n                logger.error(f\"Error maintaining ModalSandboxPool: {e}\")\n            await asyncio.sleep(1.0)\n\n    async def _maintain_pool_step(self) -> None:\n        \"\"\"Single iteration of pool maintenance: terminate used sandboxes, create new ones.\"\"\"\n        # Batch terminate used sandboxes\n        if self._to_terminate:\n            to_terminate, self._to_terminate = self._to_terminate, []\n            await asyncio.gather(*(sb.cleanup() for sb in to_terminate))\n\n        # Create new sandboxes in parallel (respecting rate limit)\n        total = self._warm_pool.qsize() + self._active_count\n        need = min(self._creation_rate_limit, self._pool_size - total)\n        if need > 0:\n            new_sandboxes = await asyncio.gather(\n                *(self._create() for _ in range(need)),\n                return_exceptions=True,\n            )\n            for sb in new_sandboxes:\n                if isinstance(sb, BaseException):\n                    logger.error(f\"Error creating Modal sandbox: {sb}\")\n                else:\n                    await self._warm_pool.put(sb)\n\n    async def run_in_workdir(\n        self,\n        files: dict[str, str],\n        command: list[str],\n        timeout: int | None = None,\n    ) -> SandboxResult:\n        \"\"\"\n        Execute command with files using an available sandbox from the pool.\n        If all sandboxes are busy, waits until one becomes available.\n\n        Creates an isolated workdir, writes files, and runs the command.\n\n        Args:\n            files: Files to write {filename: content}\n            command: Command and arguments (e.g., [\"python\", \"run.py\"])\n            timeout: Execution timeout in seconds\n        \"\"\"\n        if self._terminated:\n            raise SandboxError(\"ModalSandboxPool has 
been terminated.\")\n\n        sandbox = await self._warm_pool.get()\n        self._active_count += 1\n\n        try:\n            workdir = f\"/workspace/{uuid.uuid4().hex[:12]}\"\n            result = await sandbox.run_command(\n                f\"mkdir -p {shlex.quote(workdir)}\", timeout=timeout or 60\n            )\n            if result.exit_code != 0:\n                return SandboxResult(\n                    stdout=\"\",\n                    stderr=f\"Failed to create workdir: {workdir}\",\n                    exit_code=result.exit_code,\n                )\n\n            if files:\n                await asyncio.gather(\n                    *(\n                        sandbox.write_file(f\"{workdir}/{filename}\", content)\n                        for filename, content in files.items()\n                    )\n                )\n            return await sandbox.run_command(\n                shlex.join(command), workdir=workdir, timeout=timeout or self._sandbox_timeout_secs\n            )\n        finally:\n            self._active_count -= 1\n            self._to_terminate.append(sandbox)\n\n    async def terminate(self) -> None:\n        \"\"\"Exit the pool and terminate all sandboxes.\"\"\"\n        self._terminated = True\n\n        # Wait for active sandboxes to finish and be added to _to_terminate\n        while self._active_count > 0:\n            await asyncio.sleep(0.5)\n\n        # Collect and terminate all sandboxes\n        all_sandboxes = list(self._to_terminate)\n        while not self._warm_pool.empty():\n            try:\n                all_sandboxes.append(self._warm_pool.get_nowait())\n            except asyncio.QueueEmpty:\n                break\n        await asyncio.gather(*(sb.cleanup() for sb in all_sandboxes))\n"
  },
  {
    "path": "tinker_cookbook/sandbox/sandbox_interface.py",
    "content": "\"\"\"Sandbox interface for pluggable code execution backends.\"\"\"\n\nfrom typing import Any, Protocol, runtime_checkable\n\nimport chz\n\nfrom tinker_cookbook.exceptions import SandboxError\n\n\n@chz.chz\nclass SandboxResult:\n    \"\"\"Result from a sandbox operation.\"\"\"\n\n    stdout: str\n    stderr: str\n    exit_code: int\n    metrics: dict[str, Any] = chz.field(default_factory=dict)\n\n\nclass SandboxTerminatedError(SandboxError):\n    \"\"\"Raised when a sandbox has been terminated or died unexpectedly.\"\"\"\n\n    pass\n\n\n@runtime_checkable\nclass SandboxInterface(Protocol):\n    \"\"\"Interface for a sandbox.\n\n    Implementations must provide: run_command, read_file, write_file,\n    send_heartbeat, and cleanup.\n    \"\"\"\n\n    @property\n    def sandbox_id(self) -> str:\n        \"\"\"Identifier for the sandbox instance (e.g. Modal object_id).\"\"\"\n        ...\n\n    async def send_heartbeat(self) -> None:\n        \"\"\"Send a heartbeat to keep the sandbox alive.\n\n        If the sandbox server does not support heartbeat, this method can be a no-op.\n        \"\"\"\n        ...\n\n    async def run_command(\n        self,\n        command: str,\n        workdir: str | None = None,\n        timeout: int = 60,\n        max_output_bytes: int | None = None,\n    ) -> SandboxResult:\n        \"\"\"Run a command in the sandbox.\n\n        Setting ``workdir=None`` will run the command in the default WORKDIR set\n        in the container image (Dockerfile).\n\n        Args:\n            command: Shell command string to execute.\n            workdir: Working directory for the command.\n            timeout: Timeout in seconds.\n            max_output_bytes: Cap stdout/stderr at this many bytes. When None,\n                implementation uses its default (e.g. 128 KB).\n        \"\"\"\n        ...\n\n    async def read_file(\n        self, path: str, max_bytes: int | None = None, timeout: int = 60\n    ) -> SandboxResult:\n        \"\"\"Read the content of a file from the sandbox.\n\n        Args:\n            path: Path to the file in the sandbox.\n            max_bytes: If set, only read up to this many bytes from the file.\n            timeout: Timeout in seconds for the read operation.\n        \"\"\"\n        ...\n\n    async def write_file(\n        self, path: str, content: str | bytes, executable: bool = False, timeout: int = 60\n    ) -> SandboxResult:\n        \"\"\"Write content to a file in the sandbox.\n\n        Args:\n            path: Destination path inside the sandbox.\n            content: File content (str or bytes).\n            executable: If True, make the file executable.\n            timeout: Timeout in seconds.\n        \"\"\"\n        ...\n\n    async def cleanup(self) -> None:\n        \"\"\"Clean up the sandbox.\"\"\"\n        ...\n\n\nclass SandboxResource:\n    \"\"\"Resource wrapping a SandboxInterface.\"\"\"\n\n    def __init__(self, sandbox: SandboxInterface):\n        self.sandbox: SandboxInterface = sandbox\n\n    async def send_heartbeat(self) -> None:\n        await self.sandbox.send_heartbeat()\n\n    async def cleanup(self) -> None:\n        await self.sandbox.cleanup()\n"
  },
  {
    "path": "tinker_cookbook/sandbox/sandboxfusion.py",
    "content": "\"\"\"\nThin wrapper around SandboxFusion HTTP API.\n\nSandboxFusion is a Docker-based code execution sandbox. Run it locally with:\n\n    docker run -it -p 8080:8080 volcengine/sandbox-fusion:server-20250609\n\nConfiguration via environment variables:\n    SANDBOX_URL: Endpoint URL (default: http://localhost:8080/run_code)\n    SANDBOX_MAX_CONCURRENCY: Max concurrent requests (default: 4)\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport base64\nimport os\nfrom typing import Any\n\nimport aiohttp\n\n\nclass SandboxFusionClient:\n    \"\"\"\n    Async HTTP client for SandboxFusion code execution.\n\n    Usage:\n        client = SandboxFusionClient()\n        success, response = await client.run(\n            code=\"print('hello')\",\n            files={\"data.txt\": \"some content\"},\n            timeout=30,\n        )\n        await client.close()\n    \"\"\"\n\n    def __init__(\n        self,\n        url: str | None = None,\n        max_concurrency: int | None = None,\n    ):\n        self._url = url or os.getenv(\"SANDBOX_URL\", \"http://localhost:8080/run_code\")\n        self._max_concurrency = max_concurrency or int(os.getenv(\"SANDBOX_MAX_CONCURRENCY\", \"4\"))\n        self._session: aiohttp.ClientSession | None = None\n        self._session_lock = asyncio.Lock()\n\n    async def _get_session(self) -> aiohttp.ClientSession:\n        \"\"\"Get or create shared HTTP session with connection pooling.\n\n        The TCPConnector limits concurrent connections to max_concurrency.\n        When all connections are busy, additional requests automatically wait\n        in a queue until a connection becomes available.\n        \"\"\"\n        async with self._session_lock:\n            if self._session is None or self._session.closed:\n                connector = aiohttp.TCPConnector(\n                    limit=self._max_concurrency,\n                    limit_per_host=self._max_concurrency,\n                )\n                timeout = aiohttp.ClientTimeout(total=6000)\n                self._session = aiohttp.ClientSession(\n                    connector=connector,\n                    timeout=timeout,\n                )\n            return self._session\n\n    async def run(\n        self,\n        code: str,\n        files: dict[str, str],\n        timeout: float,\n        language: str = \"python\",\n    ) -> tuple[bool, dict[str, Any]]:\n        \"\"\"\n        Execute code with supporting files in the sandbox.\n\n        Args:\n            code: Main code to execute (entry point)\n            files: Additional files to include {filename: content}\n            timeout: Execution timeout in seconds\n            language: Programming language (default: python)\n\n        Returns:\n            Tuple of (success: bool, response: dict)\n            - success is True only if status == \"Success\"\n            - response contains the full API response or error details\n        \"\"\"\n        encoded_files = {\n            k: base64.b64encode(v.encode(\"utf-8\")).decode(\"utf-8\") for k, v in files.items()\n        }\n\n        payload = {\n            \"code\": code,\n            \"language\": language,\n            \"run_timeout\": int(timeout),\n            \"files\": encoded_files,\n        }\n\n        try:\n            session = await self._get_session()\n            async with session.post(self._url, json=payload) as resp:\n                if resp.status != 200:\n                    error_text = await resp.text()\n                    return False, 
{\"error\": f\"HTTP {resp.status}: {error_text}\"}\n\n                data: dict[str, Any] = await resp.json()\n\n                if data.get(\"status\") == \"SandboxError\":\n                    return False, {\"error\": data.get(\"message\", \"SandboxError\"), **data}\n\n                success = data.get(\"status\") == \"Success\"\n                return success, data\n\n        except Exception as e:\n            return False, {\"error\": str(e)}\n\n    async def close(self) -> None:\n        \"\"\"Close the HTTP session and release resources.\"\"\"\n        async with self._session_lock:\n            if self._session is not None and not self._session.closed:\n                await self._session.close()\n                self._session = None\n"
  },
  {
    "path": "tinker_cookbook/scripts/merge_tinker_adapter_to_hf_model.py",
    "content": "\"\"\"Merge Tinker adapter weights to a HuggingFace model, and save the new model to a given path.\n\nPlease refer to the following documentation for how to download a Tinker sampler adapter weights: https://tinker-docs.thinkingmachines.ai/download-weights\n\nUsage:\npython merge_tinker_adapter_to_hf_model.py --hf-model <name_or_path_to_hf_model> --tinker-adapter-path <local_path_to_tinker_adapter_weights> --output-path <output_path_to_save_merged_model>\n\nNOTE: This script is a thin CLI wrapper around tinker_cookbook.weights.build_hf_model().\nFor programmatic use, prefer importing from tinker_cookbook.weights directly.\n\"\"\"\n\nimport argparse\nimport warnings\n\nfrom tinker_cookbook.weights import build_hf_model\n\n\ndef main():\n    warnings.warn(\n        \"This script is deprecated. \"\n        \"Use tinker_cookbook.weights.build_hf_model() instead:\\n\\n\"\n        \"    from tinker_cookbook import weights\\n\"\n        \"    weights.build_hf_model(\\n\"\n        \"        base_model='...', adapter_path='...', output_path='...'\\n\"\n        \"    )\\n\",\n        DeprecationWarning,\n        stacklevel=2,\n    )\n    parser = argparse.ArgumentParser(\n        description=\"Merge Tinker LoRA adapter weights into a HuggingFace model.\"\n    )\n    parser.add_argument(\n        \"--tinker-adapter-path\", type=str, required=True, help=\"Path to the Tinker adapter\"\n    )\n    parser.add_argument(\n        \"--hf-model\", type=str, required=True, help=\"HuggingFace model name or path\"\n    )\n    parser.add_argument(\n        \"--output-path\", type=str, required=True, help=\"Path to save the merged model\"\n    )\n    parser.add_argument(\n        \"--quantize\",\n        type=str,\n        default=None,\n        choices=[\"experts-fp8\"],\n        help=\"Output quantization method (e.g. 'experts-fp8' for FP8 routed experts)\",\n    )\n    parser.add_argument(\n        \"--serving-format\",\n        type=str,\n        default=None,\n        choices=[\"vllm\"],\n        help=\"Serving framework format for quantization metadata (e.g. 'vllm')\",\n    )\n    args = parser.parse_args()\n\n    build_hf_model(\n        base_model=args.hf_model,\n        adapter_path=args.tinker_adapter_path,\n        output_path=args.output_path,\n        quantize=args.quantize,\n        serving_format=args.serving_format,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tinker_cookbook/scripts/test_tool_calling_e2e.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nEnd-to-end test script for tool calling across different model families.\n\nThis script queries production models with tool-calling prompts and verifies\nthat tool calls are correctly rendered, generated, and parsed.\n\nNOT a unit test - requires API access and queries real models.\n\nUsage:\n    uv run python tinker_cookbook/scripts/test_tool_calling_e2e.py [--model MODEL_NAME]\n\"\"\"\n\nimport argparse\nimport asyncio\n\nimport tinker\n\nfrom tinker_cookbook.renderers import (\n    Message,\n    ToolSpec,\n    get_renderer,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\n\n# Sample tool specifications\nSAMPLE_TOOLS: list[ToolSpec] = [\n    {\n        \"name\": \"get_weather\",\n        \"description\": \"Get the current weather for a location\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"location\": {\"type\": \"string\", \"description\": \"City name, e.g. 'San Francisco'\"},\n                \"unit\": {\n                    \"type\": \"string\",\n                    \"enum\": [\"celsius\", \"fahrenheit\"],\n                    \"description\": \"Temperature unit\",\n                },\n            },\n            \"required\": [\"location\"],\n        },\n    },\n    {\n        \"name\": \"calculate\",\n        \"description\": \"Perform a mathematical calculation\",\n        \"parameters\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"expression\": {\"type\": \"string\", \"description\": \"Math expression to evaluate\"},\n            },\n            \"required\": [\"expression\"],\n        },\n    },\n]\n\n\n# Model configurations for testing\nMODEL_CONFIGS = [\n    {\n        \"model_name\": \"Qwen/Qwen3-8B\",\n        \"renderer_name\": \"qwen3\",\n    },\n    {\n        \"model_name\": \"Qwen/Qwen3-30B-A3B-Instruct-2507\",\n        \"renderer_name\": \"qwen3_instruct\",\n    },\n    {\n        \"model_name\": \"meta-llama/Llama-3.1-8B-Instruct\",\n        \"renderer_name\": \"llama3\",\n    },\n    {\n        \"model_name\": \"deepseek-ai/DeepSeek-V3.1\",\n        \"renderer_name\": \"deepseekv3\",\n    },\n    {\n        \"model_name\": \"moonshotai/Kimi-K2-Thinking\",\n        \"renderer_name\": \"kimi_k2\",\n    },\n    {\n        \"model_name\": \"openai/gpt-oss-20b\",\n        \"renderer_name\": \"gpt_oss_medium_reasoning\",\n    },\n]\n\n\ndef print_result(\n    model_name: str,\n    success: bool,\n    message: Message,\n    raw_response: str,\n):\n    \"\"\"Print formatted test result.\"\"\"\n    status = \"✓\" if success else \"✗\"\n    print(f\"\\n{'=' * 60}\")\n    print(f\"{status} {model_name}\")\n    print(f\"{'=' * 60}\")\n\n    if \"tool_calls\" in message and message[\"tool_calls\"]:  # noqa: RUF019\n        print(f\"Tool calls found: {len(message['tool_calls'])}\")\n        for i, tc in enumerate(message[\"tool_calls\"]):\n            print(f\"  [{i}] {tc.function.name}({tc.function.arguments})\")\n    else:\n        print(\"No tool calls found\")\n\n    print(f\"\\nContent: {message.get('content', '')[:200]}...\")\n    print(f\"\\nRaw response (first 500 chars):\\n{raw_response[:500]}\")\n\n\nasync def test_model(\n    service_client: tinker.ServiceClient,\n    model_name: str,\n    renderer_name: str,\n    tools: list[ToolSpec],\n    system_prompt: str,\n    user_prompt: str,\n) -> tuple[bool, Message, str]:\n    \"\"\"\n    Test tool calling for a single model.\n\n    Returns:\n        
Tuple of (success, parsed_message, raw_response)\n    \"\"\"\n    print(f\"\\nTesting {model_name}...\")\n\n    # Get tokenizer and renderer\n    tokenizer = get_tokenizer(model_name)\n    renderer = get_renderer(renderer_name, tokenizer)\n\n    # Build messages using the unified interface\n    prefix_messages = renderer.create_conversation_prefix_with_tools(tools, system_prompt)\n    messages: list[Message] = prefix_messages + [{\"role\": \"user\", \"content\": user_prompt}]\n\n    # Build prompt\n    prompt = renderer.build_generation_prompt(messages)\n    stop_sequences = renderer.get_stop_sequences()\n\n    # Create sampling client\n    sampling_client = service_client.create_sampling_client(base_model=model_name)\n    result = await sampling_client.sample_async(\n        prompt=prompt,\n        num_samples=1,\n        sampling_params=tinker.SamplingParams(\n            stop=stop_sequences,\n            max_tokens=512,\n            temperature=0.0,  # Deterministic for testing\n        ),\n    )\n\n    # Parse response\n    response_tokens = result.sequences[0].tokens\n    raw_response = str(tokenizer.decode(response_tokens))\n    message, parse_success = renderer.parse_response(response_tokens)\n\n    # Check if we got tool calls\n    has_tool_calls = \"tool_calls\" in message and len(message.get(\"tool_calls\", [])) > 0\n    success = parse_success and has_tool_calls\n\n    return success, message, raw_response\n\n\nasync def main():\n    parser = argparse.ArgumentParser(description=\"Test tool calling across models\")\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        help=\"Specific model to test (default: test all)\",\n    )\n    parser.add_argument(\n        \"--prompt\",\n        type=str,\n        default=\"What's the weather like in San Francisco?\",\n        help=\"User prompt to send\",\n    )\n    args = parser.parse_args()\n\n    # Filter models if specific one requested\n    configs = MODEL_CONFIGS\n    if args.model:\n        configs = [c for c in configs if args.model in c[\"model_name\"]]\n        if not configs:\n            print(f\"No matching model found for: {args.model}\")\n            print(f\"Available models: {[c['model_name'] for c in MODEL_CONFIGS]}\")\n            return\n\n    print(\"=\" * 60)\n    print(\"Tool Calling End-to-End Test\")\n    print(\"=\" * 60)\n    print(f\"User prompt: {args.prompt}\")\n    print(f\"Models to test: {[c['model_name'] for c in configs]}\")\n\n    # Create service client (shared across all model tests)\n    service_client = tinker.ServiceClient()\n\n    system_prompt = \"You are a helpful assistant.\"\n    results = []\n    for config in configs:\n        try:\n            success, message, raw_response = await test_model(\n                service_client=service_client,\n                model_name=config[\"model_name\"],\n                renderer_name=config[\"renderer_name\"],\n                tools=SAMPLE_TOOLS,\n                system_prompt=system_prompt,\n                user_prompt=args.prompt,\n            )\n            print_result(config[\"model_name\"], success, message, raw_response)\n            results.append((config[\"model_name\"], success))\n        except Exception as e:\n            print(f\"\\n✗ {config['model_name']}: Error - {e}\")\n            results.append((config[\"model_name\"], False))\n\n    # Summary\n    print(\"\\n\" + \"=\" * 60)\n    print(\"Summary\")\n    print(\"=\" * 60)\n    passed = sum(1 for _, s in results if s)\n    total = len(results)\n    
print(f\"Passed: {passed}/{total}\")\n    for model, success in results:\n        status = \"✓\" if success else \"✗\"\n        print(f\"  {status} {model}\")\n\n\nif __name__ == \"__main__\":\n    asyncio.run(main())\n"
  },
  {
    "path": "tinker_cookbook/supervised/__init__.py",
    "content": ""
  },
  {
    "path": "tinker_cookbook/supervised/common.py",
    "content": "import logging\n\nimport tinker\nimport torch\n\nfrom tinker_cookbook.exceptions import DataValidationError\n\nlogger = logging.getLogger(__name__)\n\n\ndef compute_mean_nll(\n    logprobs_list: list[tinker.TensorData], weights_list: list[tinker.TensorData]\n) -> float:\n    \"\"\"Compute weighted mean negative log likelihood.\"\"\"\n    total_weighted_logprobs = 0.0\n    total_weights = 0.0\n\n    for logprobs, weights in zip(logprobs_list, weights_list, strict=True):\n        logprobs_torch = logprobs.to_torch()\n        weights_torch = weights.to_torch()\n        total_weighted_logprobs += logprobs_torch.dot(weights_torch)\n        total_weights += weights_torch.sum()\n\n    if total_weights == 0:\n        logger.warning(\"No valid weights found for NLL computation\")\n        return float(\"nan\")\n\n    return float(-total_weighted_logprobs / total_weights)\n\n\ndef create_rightshifted_model_input_and_leftshifted_targets(\n    chunks: list[tinker.ModelInputChunk],\n) -> tuple[tinker.ModelInput, list[int]]:\n    \"\"\"\n    Given a full sequence of model input chunks, create\n     \"inputs\" (with last token removed); these are also list[ModelInputChunk] because text+images\n     \"targets\" (with first token removed); these are list[int] text tokens\n    \"\"\"\n    assert len(chunks) >= 1, \"must have at least one chunk\"\n\n    last_chunk = chunks[-1]\n    if not isinstance(last_chunk, tinker.types.EncodedTextChunk):\n        raise DataValidationError(\n            \"The last chunk must be a text chunk. This is because images are 0-loss anyways, so we should remove them beforehand.\"\n        )\n\n    total_length = sum(c.length for c in chunks)\n    if total_length < 2:\n        raise DataValidationError(\"need at least 2 tokens for input/target split\")\n\n    # Build input chunks: all but last, then append truncated last chunk\n    input_chunks: list[tinker.ModelInputChunk] = list(chunks[:-1])\n    if last_chunk.length > 1:\n        input_chunks.append(tinker.types.EncodedTextChunk(tokens=last_chunk.tokens[:-1]))\n\n    # Build target tokens: collect all tokens, then slice off first\n    all_tokens: list[int] = []\n    for chunk in chunks:\n        if isinstance(chunk, tinker.types.EncodedTextChunk):\n            all_tokens.extend(chunk.tokens)\n        else:\n            all_tokens.extend([0] * chunk.length)\n    target_tokens = all_tokens[1:]\n\n    return tinker.ModelInput(chunks=input_chunks), target_tokens\n\n\ndef datum_from_model_input_weights(\n    model_input: tinker.ModelInput,\n    weights: torch.Tensor,\n    max_length: int | None = None,\n) -> tinker.Datum:\n    \"\"\"\n    Create a Datum from a ModelInput and weights tensor.\n\n    Performs max_length truncation and next-token slicing to create input and target.\n    Text chunks can be truncated, but image chunks must be wholly discarded to stay\n    within max_length.\n\n    Args:\n        model_input: The model input containing a sequence of text and/or image chunks\n        weights: The weights tensor aligned with the model_input length\n        max_length: Optional maximum sequence length. 
If provided, truncates to this length.\n                   Image chunks are discarded entirely if they would exceed max_length.\n\n    Returns:\n        A Datum with model_input (input tokens) and loss_fn_inputs (target tokens and weights)\n    \"\"\"\n\n    model_input_chunks = list(model_input.chunks)\n\n    # Truncate to max_length by popping from end\n    if max_length is not None:\n        total_length = sum(chunk.length for chunk in model_input_chunks)\n\n        while total_length > max_length and model_input_chunks:\n            last = model_input_chunks[-1]\n            if isinstance(last, tinker.types.EncodedTextChunk):\n                overflow = total_length - max_length\n                if overflow < last.length:\n                    # Partial truncation of text chunk\n                    model_input_chunks[-1] = tinker.types.EncodedTextChunk(\n                        tokens=list(last.tokens[:-overflow])\n                    )\n                    total_length = max_length\n                else:\n                    # Remove entire text chunk\n                    model_input_chunks.pop()\n                    total_length -= last.length\n            else:\n                # Image chunk - must remove entirely\n                model_input_chunks.pop()\n                total_length -= last.length\n\n    # Remove trailing images (no text to predict after them)\n    while model_input_chunks and isinstance(\n        model_input_chunks[-1], (tinker.types.ImageChunk, tinker.types.ImageAssetPointerChunk)\n    ):\n        model_input_chunks.pop()\n\n    input_model_input, target_tokens = create_rightshifted_model_input_and_leftshifted_targets(\n        model_input_chunks\n    )\n    weights = weights[1 : len(target_tokens) + 1]\n\n    return tinker.Datum(\n        model_input=input_model_input,\n        loss_fn_inputs={\n            \"weights\": tinker.TensorData(\n                data=weights.tolist(),\n                dtype=\"float32\",\n                shape=list(weights.shape),\n            ),\n            \"target_tokens\": tinker.TensorData(\n                data=target_tokens,\n                dtype=\"int64\",\n                shape=[len(target_tokens)],\n            ),\n        },\n    )\n"
  },
  {
    "path": "tinker_cookbook/supervised/data.py",
    "content": "\"\"\"\nSupervised learning dataset implementations from HuggingFace datasets.\n\"\"\"\n\nimport json\nfrom collections.abc import Callable\nfrom typing import Any\n\nimport blobfile\nimport chz\nimport datasets\nimport tinker\n\nfrom tinker_cookbook.exceptions import DataFormatError, DataValidationError\nfrom tinker_cookbook.renderers import Message, Renderer, TrainOnWhat\nfrom tinker_cookbook.supervised.common import datum_from_model_input_weights\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilder, SupervisedDataset\n\n\ndef conversation_to_datum(\n    conversation: list[Message],\n    renderer: Renderer,\n    max_length: int | None,\n    train_on_what: TrainOnWhat = TrainOnWhat.ALL_ASSISTANT_MESSAGES,\n) -> tinker.Datum:\n    \"\"\"Common function to process a list of messages into a Datum.\"\"\"\n    model_input, weights = renderer.build_supervised_example(\n        conversation, train_on_what=train_on_what\n    )\n    return datum_from_model_input_weights(model_input, weights, max_length)\n\n\ndef _one_of(a: Any, b: Any) -> bool:\n    return (a is not None and b is None) or (a is None and b is not None)\n\n\nclass SupervisedDatasetFromHFDataset(SupervisedDataset):\n    def __init__(\n        self,\n        hf_dataset: datasets.Dataset,\n        batch_size: int,\n        map_fn: Callable[[dict], tinker.Datum] | None = None,\n        flatmap_fn: Callable[[dict], list[tinker.Datum]] | None = None,\n    ):\n        assert _one_of(map_fn, flatmap_fn), \"Only one of map_fn or flatmap_fn can be provided\"\n        self.hf_dataset = hf_dataset\n        self.shuffle_dataset = (\n            hf_dataset  # Keep a reference to the original dataset to avoid statefulness\n        )\n        self.batch_size = batch_size\n        self.map_fn = map_fn\n        self.flatmap_fn = flatmap_fn\n\n    def get_batch(self, index: int) -> list[tinker.Datum]:\n        rows = self.shuffle_dataset.select(\n            range(index * self.batch_size, (index + 1) * self.batch_size)\n        )\n        if self.map_fn is not None:\n            return [self.map_fn(row) for row in rows.to_list()]\n        else:\n            assert self.flatmap_fn is not None\n            return [datum for row in rows.to_list() for datum in self.flatmap_fn(row)]\n\n    def set_epoch(self, seed: int = 0):\n        self.shuffle_dataset = self.hf_dataset.shuffle(seed=seed)\n\n    def __len__(self) -> int:\n        return len(self.hf_dataset) // self.batch_size\n\n\nclass StreamingSupervisedDatasetFromHFDataset(SupervisedDataset):\n    def __init__(\n        self,\n        hf_dataset: datasets.IterableDataset,\n        batch_size: int,\n        length: int,\n        map_fn: Callable[[dict], tinker.Datum] | None = None,\n        flatmap_fn: Callable[[dict], list[tinker.Datum]] | None = None,\n        buffer_size: int = 10_000,\n    ):\n        assert _one_of(map_fn, flatmap_fn), \"Only one of map_fn or flatmap_fn can be provided\"\n        self.hf_dataset = hf_dataset.shuffle(seed=0, buffer_size=buffer_size).batch(\n            batch_size=batch_size, drop_last_batch=True\n        )\n        self.dataset_iterator = iter(self.hf_dataset)\n        self.index = -1\n        self.batch_size = batch_size\n        self.map_fn = map_fn\n        self.flatmap_fn = flatmap_fn\n        # We pass the length to the dataset, since streaming HF datasets don't have a length attribute\n        self.length = length\n\n    def get_batch(self, index: int) -> list[tinker.Datum]:\n        # Error on backward seeks\n        if index < 
self.index + 1:\n            raise DataValidationError(\n                f\"StreamingSupervisedDatasetFromHFDataset only supports forward iteration. \"\n                f\"Cannot seek backward from batch {self.index} to {index}.\"\n            )\n\n        # Skip forward if needed by consuming intermediate batches\n        batches_to_skip = index - self.index - 1\n        for _ in range(batches_to_skip):\n            next(self.dataset_iterator)\n            self.index += 1\n\n        # Get the actual batch\n        self.index = index\n        batch = next(self.dataset_iterator)\n        rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]\n        if self.map_fn is not None:\n            return [self.map_fn(row) for row in rows]\n        else:\n            assert self.flatmap_fn is not None\n            return [datum for row in rows for datum in self.flatmap_fn(row)]\n\n    def set_epoch(self, seed: int = 0):\n        self.hf_dataset.set_epoch(seed)\n        self.dataset_iterator = iter(self.hf_dataset)\n        self.index = -1\n\n    def __len__(self) -> int:\n        return self.length // self.batch_size\n\n\n@chz.chz\nclass FromConversationFileBuilder(ChatDatasetBuilder):\n    file_path: str\n    test_size: int = 0\n    shuffle_seed: int = 0\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        # Load conversations from JSONL file\n        conversations = []\n        with blobfile.BlobFile(self.file_path, \"r\", streaming=False) as f:\n            for line in f:\n                data = json.loads(line.strip())\n                if \"messages\" not in data:\n                    raise DataFormatError(\n                        f\"Each line in the JSONL file must contain a 'messages' field. Got: {data.keys()}\"\n                    )\n                conversations.append(data)\n\n        # Create HuggingFace dataset from the loaded data\n        dataset = datasets.Dataset.from_list(conversations)\n\n        # Shuffle if seed is provided\n        if self.shuffle_seed is not None:\n            dataset = dataset.shuffle(seed=self.shuffle_seed)\n\n        # Split into train and test\n        if self.test_size > 0 and len(dataset) > self.test_size:\n            test_ds = dataset.take(self.test_size)\n            train_ds = dataset.skip(self.test_size)\n        else:\n            # If test_size is 0 or dataset is too small, use all data for training\n            train_ds = dataset\n            test_ds = None\n\n        # Use train_on_what from common_config if provided, otherwise use default\n        train_on_what = (\n            TrainOnWhat(self.common_config.train_on_what)\n            if self.common_config.train_on_what\n            else TrainOnWhat.ALL_ASSISTANT_MESSAGES\n        )\n\n        # Define mapping function\n        def map_fn(row: dict) -> tinker.Datum:\n            return conversation_to_datum(\n                row[\"messages\"], self.renderer, self.common_config.max_length, train_on_what\n            )\n\n        # Create supervised dataset\n        supervised_dataset = SupervisedDatasetFromHFDataset(\n            train_ds, batch_size=self.common_config.batch_size, map_fn=map_fn\n        )\n\n        # Create evaluator if we have test data\n        if test_ds is not None:\n            test_dataset = SupervisedDatasetFromHFDataset(\n                test_ds, batch_size=len(test_ds), map_fn=map_fn\n            )\n        else:\n            test_dataset = None\n\n        return supervised_dataset, test_dataset\n"
  },
  {
    "path": "tinker_cookbook/supervised/nll_evaluator.py",
    "content": "import itertools\n\nimport tinker\n\nfrom tinker_cookbook.eval.evaluators import TrainingClientEvaluator\nfrom tinker_cookbook.supervised.common import compute_mean_nll\nfrom tinker_cookbook.supervised.types import SupervisedDataset\n\n\nclass NLLEvaluator(TrainingClientEvaluator):\n    def __init__(self, data: list[tinker.Datum], name: str = \"test\"):\n        self.name = name\n        self.data = data\n\n    async def __call__(self, training_client: tinker.TrainingClient) -> dict[str, float]:\n        future = await training_client.forward_async(self.data, loss_fn=\"cross_entropy\")\n        result = await future.result_async()\n        logprobs = [x[\"logprobs\"] for x in result.loss_fn_outputs]\n        weights = [datum.loss_fn_inputs[\"weights\"] for datum in self.data]\n        nll = compute_mean_nll(logprobs, weights)\n        key = f\"{self.name}/nll\"\n        return {key: nll}\n\n    @classmethod\n    def from_dataset(cls, dataset: SupervisedDataset, name: str = \"test\") -> \"NLLEvaluator\":\n        all_data = list(itertools.chain(*[dataset.get_batch(i) for i in range(len(dataset))]))\n        return cls(all_data, name=name)\n"
  },
  {
    "path": "tinker_cookbook/supervised/resume_test.py",
    "content": "\"\"\"Test for checkpoint resume functionality in supervised training.\"\"\"\n\nimport asyncio\nimport contextlib\nimport json\nimport os\nimport tempfile\nfrom typing import Any\nfrom unittest.mock import MagicMock, patch\n\nfrom tinker_cookbook import checkpoint_utils, renderers\nfrom tinker_cookbook.recipes.chat_sl import chat_datasets\nfrom tinker_cookbook.supervised import train\nfrom tinker_cookbook.supervised.types import ChatDatasetBuilderCommonConfig\nfrom tinker_cookbook.utils.file_utils import read_jsonl\n\n\nclass StopTrainingException(Exception):\n    \"\"\"Exception to stop training at a specific step.\"\"\"\n\n\ndef create_mock_logger_with_jsonl(\n    log_path: str,\n    interrupt_at_step: int | None = None,\n    interrupt_exception_class: type[Exception] | None = None,\n    metrics_filename: str = \"metrics.jsonl\",\n) -> MagicMock:\n    \"\"\"Create a mock logger that writes metrics to JSONL and optionally interrupts at a specific step.\"\"\"\n    mock_logger = MagicMock()\n\n    def log_metrics(metrics: dict[str, Any], step: int):\n        jsonl_path = os.path.join(log_path, metrics_filename)\n        with open(jsonl_path, \"a\") as f:\n            f.write(json.dumps({\"step\": step, **metrics}) + \"\\n\")\n            print(f\"Step {step} metrics: {metrics}\")\n\n        if interrupt_at_step is not None and step == interrupt_at_step:\n            if interrupt_exception_class is None:\n                raise ValueError(\n                    \"interrupt_exception_class must be provided if interrupt_at_step is set\"\n                )\n            raise interrupt_exception_class(f\"Interrupting at step {step}\")\n\n    mock_logger.log_metrics = log_metrics\n    mock_logger.close = MagicMock()\n    mock_logger.get_logger_url = MagicMock(return_value=None)\n\n    return mock_logger\n\n    pass\n\n\ndef checkpoint_resume():\n    interrupt_step = 8\n    with tempfile.TemporaryDirectory() as tmpdir:\n        log_path = tmpdir\n        os.makedirs(log_path, exist_ok=True)\n\n        # Use the real NoRobots dataset with a small batch size\n        model_name = \"meta-llama/Llama-3.2-1B\"\n        common_config = ChatDatasetBuilderCommonConfig(\n            model_name_for_tokenizer=model_name,\n            renderer_name=\"role_colon\",\n            max_length=1024,\n            batch_size=32,\n            train_on_what=renderers.TrainOnWhat.ALL_ASSISTANT_MESSAGES,\n        )\n\n        # Create config\n        config = train.Config(\n            log_path=log_path,\n            model_name=model_name,\n            dataset_builder=chat_datasets.NoRobotsBuilder(common_config=common_config),\n            num_epochs=1,\n            save_every=5,\n            eval_every=0,\n            infrequent_eval_every=0,\n            wandb_project=None,\n            lora_rank=16,\n            learning_rate=1e-5,\n        )\n\n        # Ensure interrupt happens after checkpoint\n        assert interrupt_step > config.save_every, (\n            f\"interrupt_step ({interrupt_step}) must be > save_every ({config.save_every}) \"\n            \"to test checkpoint resume\"\n        )\n\n        # First run - stop at interrupt_step\n        with patch(\"tinker_cookbook.utils.ml_log.setup_logging\") as mock_setup_logging:\n            mock_logger = create_mock_logger_with_jsonl(\n                log_path=log_path,\n                interrupt_at_step=interrupt_step,\n                interrupt_exception_class=StopTrainingException,\n            )\n            mock_setup_logging.return_value = 
mock_logger\n\n            # Run until exception\n            with contextlib.suppress(StopTrainingException):\n                asyncio.run(train.main(config))\n\n        # Verify checkpoint was saved at step 5\n        checkpoints = checkpoint_utils.load_checkpoints_file(log_path)\n        assert len(checkpoints) > 0, \"Should have at least one checkpoint\"\n        assert checkpoints[0].name == \"000005\", \"First checkpoint should be at step 5\"\n\n        # Read first run metrics\n        first_run_metrics = read_jsonl(os.path.join(log_path, \"metrics.jsonl\"))\n\n        # Second run - resume from checkpoint\n        with patch(\"tinker_cookbook.utils.ml_log.setup_logging\") as mock_setup_logging:\n            mock_logger2 = create_mock_logger_with_jsonl(\n                log_path=log_path,\n                metrics_filename=\"metrics_run2.jsonl\",\n                interrupt_at_step=interrupt_step,\n                interrupt_exception_class=StopTrainingException,\n            )\n            mock_setup_logging.return_value = mock_logger2\n\n            with contextlib.suppress(StopTrainingException):\n                asyncio.run(train.main(config))\n\n        # Read second run metrics\n        second_run_metrics = read_jsonl(os.path.join(log_path, \"metrics_run2.jsonl\"))\n\n        # Extract losses\n        first_losses = {\n            m[\"step\"]: m[\"train_mean_nll\"] for m in first_run_metrics if \"train_mean_nll\" in m\n        }\n        second_losses = {\n            m[\"step\"]: m[\"train_mean_nll\"] for m in second_run_metrics if \"train_mean_nll\" in m\n        }\n\n        overlap_steps = [5, 6, 7]\n        # Check that steps 5-7 have approximately the same losses in both runs\n        # (We resumed from the checkpoint saved at step 5, so the overlapping steps should be similar)\n        for step in overlap_steps:\n            assert step in first_losses, f\"Step {step} missing from first run\"\n            assert step in second_losses, f\"Step {step} missing from second run\"\n\n            # Losses should be very close (within 1% relative difference)\n            loss1 = first_losses[step]\n            loss2 = second_losses[step]\n            relative_diff = abs(loss1 - loss2) / max(abs(loss1), abs(loss2))\n            assert relative_diff < 0.01, (\n                f\"Loss at step {step} should be similar: \"\n                f\"{loss1} vs {loss2} (relative diff: {relative_diff:.2%})\"\n            )\n\n        print(\"✓ Test passed: training resumed correctly from checkpoint\")\n        print(f\"  First run losses at steps 5-7: {[first_losses[i] for i in overlap_steps]}\")\n        print(f\"  Second run losses at steps 5-7: {[second_losses[i] for i in overlap_steps]}\")\n\n\nif __name__ == \"__main__\":\n    checkpoint_resume()\n"
  },
  {
    "path": "tinker_cookbook/supervised/train.py",
    "content": "\"\"\"\nSupervised fine-tuning (SFT)\n\nThis module implements a pipelined supervised learning training loop. For background on\nwhy we pipeline requests, see https://tinker-docs.thinkingmachines.ai/under-the-hood.\nFor a minimal, pedagogical example of SL training without these optimizations,\nrefer to `tinker_cookbook/recipes/sl_loop.py`.\n\"\"\"\n\nimport asyncio\nimport logging\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport chz\nimport tinker\nfrom tinker.lib.public_interfaces import APIFuture\n\nfrom tinker_cookbook import checkpoint_utils, model_info\nfrom tinker_cookbook.display import colorize_example\nfrom tinker_cookbook.eval.evaluators import (\n    Evaluator,\n    EvaluatorBuilder,\n    SamplingClientEvaluator,\n    TrainingClientEvaluator,\n)\nfrom tinker_cookbook.exceptions import ConfigurationError\nfrom tinker_cookbook.supervised.common import compute_mean_nll\nfrom tinker_cookbook.supervised.nll_evaluator import NLLEvaluator\nfrom tinker_cookbook.supervised.types import SupervisedDatasetBuilder\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.utils import ml_log, trace\nfrom tinker_cookbook.utils.lr_scheduling import LRSchedule, compute_schedule_lr_multiplier\n\nlogger = logging.getLogger(__name__)\n\n\n@chz.chz\nclass Config:\n    \"\"\"Configuration for supervised fine-tuning.\"\"\"\n\n    # Required parameters\n    log_path: str = chz.field(munger=lambda _, s: str(Path(s).expanduser()))\n    model_name: str\n    load_checkpoint_path: str | None = None\n    renderer_name: str | None = None\n    dataset_builder: SupervisedDatasetBuilder\n\n    # Training parameters\n    learning_rate: float = 1e-4\n    lr_schedule: LRSchedule = \"linear\"\n    num_epochs: int = 1\n\n    # Model parameters\n    lora_rank: int = 32\n\n    # Infrastructure parameters\n    base_url: str | None = None\n\n    # Checkpointing and evaluation (0 = disabled for *_every fields)\n    evaluator_builders: list[EvaluatorBuilder] = chz.field(default_factory=list)\n    infrequent_evaluator_builders: list[EvaluatorBuilder] = chz.field(default_factory=list)\n    save_every: int = 20\n    eval_every: int = 10\n    infrequent_eval_every: int = 100\n    # Periodic checkpoints use this TTL; the final checkpoint is kept indefinitely.\n    ttl_seconds: int | None = 604800  # 7 days\n\n    # Adam optimizer parameters\n    adam_beta1: float = 0.9\n    adam_beta2: float = 0.95\n    adam_eps: float = 1e-8\n\n    # Logging parameters\n    wandb_project: str | None = None\n    wandb_name: str | None = None\n\n    enable_trace: bool = False\n    span_chart_every: int = 0\n\n    # Maximum number of training steps. 
If None, train for num_epochs * n_batches.\n    max_steps: int | None = None\n\n\n@dataclass\nclass SubmittedBatch:\n    fwd_bwd_future: APIFuture[tinker.ForwardBackwardOutput]\n    optim_step_future: APIFuture[tinker.OptimStepResponse]\n    metrics: dict[str, int | float | str]\n    data: list\n    step: int\n    epoch_idx: int\n    batch_idx: int\n    eval_metrics: dict[str, float] | None = None\n    infrequent_eval_metrics: dict[str, float] | None = None\n\n\n@trace.scope\nasync def run_evals(\n    evaluators: list[Evaluator],\n    training_client: tinker.TrainingClient,\n    step: int,\n) -> dict[str, float]:\n    \"\"\"Evaluate the current model weights and prefix results with ``test/``.\n\n    The helper is called immediately before optimizer step `step` is submitted, so it\n    measures the weights produced after step `step-1` (or the initial weights for step 0).\n    Training-client evaluators run against the mutable training client, while sampling\n    evaluators request a fresh `SamplingClient` snapshot via\n    `save_weights_and_get_sampling_client_async` to ensure their work uses a fixed\n    checkpoint. Returned metrics are prefixed with ``test/`` so they can be logged next\n    to the same-step training metrics.\n    \"\"\"\n    trace.update_scope_context({\"step\": step})\n\n    metrics = {}\n    sampling_client = None\n\n    @trace.scope\n    async def run_evaluator(evaluator: Evaluator) -> dict[str, float]:\n        trace.update_scope_context(\n            {\n                \"step\": step,\n                \"evaluator_name\": type(evaluator).__name__,\n            }\n        )\n        if isinstance(evaluator, TrainingClientEvaluator):\n            trace.update_scope_context({\"evaluator_type\": \"TrainingClientEvaluator\"})\n            return await evaluator(training_client)\n        elif isinstance(evaluator, SamplingClientEvaluator):\n            trace.update_scope_context({\"evaluator_type\": \"SamplingClientEvaluator\"})\n            # Create sampling client lazily, only when needed\n            nonlocal sampling_client\n            if sampling_client is None:\n                # Snapshot the current pre-step weights and create a new sampling client.\n                sampling_client = await training_client.save_weights_and_get_sampling_client_async(\n                    f\"evals_step_{step}\"\n                )\n            return await evaluator(sampling_client)\n        else:\n            raise ConfigurationError(f\"Unknown evaluator type: {type(evaluator)}\")\n\n    for evaluator in evaluators:\n        eval_metrics = await run_evaluator(evaluator)\n        # Add test/ prefix to all metrics\n        metrics.update(eval_metrics)\n\n    return metrics\n\n\n@trace.scope\nasync def main(config: Config):\n    \"\"\"Run the standard supervised learning loop used by the supervised recipes.\n\n    Responsibilities:\n    1. Initialize logging, build the dataset/evaluator objects, construct (or resume) the\n       training client, and determine the ``epoch``/``batch`` indices to start from.\n    2. Iterate over batches: fetch data, optionally run evaluations before submitting the\n       optimizer step (so they observe pre-step weights), issue `forward_backward` and\n       `optim_step` requests, and log metrics once the futures resolve.\n    3. 
Save checkpoints at the configured cadence so runs can resume or export weights,\n       then emit a final checkpoint when training completes.\n\n    Training and evaluation metrics share the same ``step`` index to keep dashboards easy\n    to read.\n    \"\"\"\n    resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)\n    if resume_info:\n        start_epoch = resume_info.epoch or 0\n        start_batch = resume_info.batch\n    else:\n        start_epoch = 0\n        start_batch = 0\n    # (start_epoch, start_batch) now represent the next batch to execute if resuming.\n\n    ml_logger = ml_log.setup_logging(\n        log_dir=config.log_path,\n        wandb_project=config.wandb_project,\n        wandb_name=config.wandb_name,\n        config=config,\n        do_configure_logging_module=True,\n    )\n    if config.enable_trace:\n        # Get and rename the current (main) task\n        current_task = asyncio.current_task()\n        if current_task is not None:\n            current_task.set_name(\"main\")\n        trace_events_path = str(Path(config.log_path) / \"trace_events.jsonl\")\n        logger.info(f\"Tracing is enabled. Trace events will be saved to {trace_events_path}\")\n        logger.info(\n            f\"Run `python tinker_cookbook/utils/trace.py {trace_events_path} trace.json` and visualize in chrome://tracing or https://ui.perfetto.dev/\"\n        )\n        trace.trace_init(output_file=trace_events_path)\n\n    service_client = tinker.ServiceClient(base_url=config.base_url)\n\n    user_metadata: dict[str, str] = {}\n    if wandb_link := ml_logger.get_logger_url():\n        user_metadata[\"wandb_link\"] = wandb_link\n    checkpoint_utils.add_renderer_name_to_user_metadata(user_metadata, config.renderer_name)\n    model_info.warn_if_renderer_not_recommended(config.model_name, config.renderer_name)\n\n    if resume_info:\n        # Resuming interrupted training - load optimizer state for proper continuation\n        await checkpoint_utils.check_renderer_name_for_checkpoint_async(\n            service_client, resume_info.state_path, config.renderer_name\n        )\n        training_client = (\n            await service_client.create_training_client_from_state_with_optimizer_async(\n                resume_info.state_path, user_metadata=user_metadata\n            )\n        )\n        logger.info(f\"Resumed training from {resume_info.state_path}\")\n    elif config.load_checkpoint_path:\n        # Starting fresh from a checkpoint - load weights only (fresh optimizer)\n        await checkpoint_utils.check_renderer_name_for_checkpoint_async(\n            service_client, config.load_checkpoint_path, config.renderer_name\n        )\n        training_client = await service_client.create_training_client_from_state_async(\n            config.load_checkpoint_path, user_metadata=user_metadata\n        )\n        logger.info(f\"Loaded weights from {config.load_checkpoint_path}\")\n    else:\n        training_client = await service_client.create_lora_training_client_async(\n            base_model=config.model_name,\n            rank=config.lora_rank,\n            user_metadata=user_metadata,\n        )\n\n    dataset, maybe_test_dataset = config.dataset_builder()\n    n_batches = len(dataset)\n    total_steps = n_batches * config.num_epochs\n    if config.max_steps is not None:\n        total_steps = min(total_steps, config.max_steps)\n    progress_denominator = total_steps if total_steps > 0 else 1\n    tokenizer = get_tokenizer(config.model_name)\n\n    evaluators = 
[evaluator() for evaluator in config.evaluator_builders]\n    if maybe_test_dataset is not None:\n        evaluators.append(NLLEvaluator.from_dataset(maybe_test_dataset))\n\n    infrequent_evaluators = [evaluator() for evaluator in config.infrequent_evaluator_builders]\n    logger.info(\n        f\"Training for {n_batches} batches x {config.num_epochs} epochs = {n_batches * config.num_epochs} steps\"\n    )\n\n    @trace.scope\n    async def submit_batch(epoch_idx: int, batch_idx: int) -> SubmittedBatch:\n        step = epoch_idx * n_batches + batch_idx\n        trace.update_scope_context({\"step\": step})\n\n        metrics: dict[str, int | float | str] = {\"epoch\": epoch_idx}\n        metrics[\"progress\"] = step / progress_denominator\n\n        learning_rate = config.learning_rate * compute_schedule_lr_multiplier(\n            lr_schedule=config.lr_schedule,\n            step=step,\n            total_steps=total_steps,\n        )\n        metrics[\"learning_rate\"] = learning_rate\n\n        adam_params = tinker.AdamParams(\n            learning_rate=learning_rate,\n            beta1=config.adam_beta1,\n            beta2=config.adam_beta2,\n            eps=config.adam_eps,\n        )\n\n        async with trace.scope_span(\"get_batch\"):\n            data = dataset.get_batch(batch_idx)\n        if data:\n            logger.info(colorize_example(data[0], tokenizer))\n\n        # Trigger evaluations BEFORE submitting training operations so they snapshot pre-step weights\n        eval_metrics = None\n        if evaluators and config.eval_every > 0 and step % config.eval_every == 0:\n            async with trace.scope_span(\"evals\"):\n                eval_metrics = await run_evals(evaluators, training_client, step)\n\n        infrequent_eval_metrics = None\n        if (\n            infrequent_evaluators\n            and config.infrequent_eval_every > 0\n            and step % config.infrequent_eval_every == 0\n        ):\n            async with trace.scope_span(\"infrequent_evals\"):\n                infrequent_eval_metrics = await run_evals(\n                    infrequent_evaluators, training_client, step\n                )\n\n        fwd_bwd_future = await training_client.forward_backward_async(data, loss_fn=\"cross_entropy\")\n        optim_step_future = await training_client.optim_step_async(adam_params)\n\n        return SubmittedBatch(\n            fwd_bwd_future=fwd_bwd_future,\n            optim_step_future=optim_step_future,\n            metrics=metrics,\n            data=data,\n            step=step,\n            epoch_idx=epoch_idx,\n            batch_idx=batch_idx,\n            eval_metrics=eval_metrics,\n            infrequent_eval_metrics=infrequent_eval_metrics,\n        )\n\n    @trace.scope\n    async def finish_batch(submitted: SubmittedBatch):\n        trace.update_scope_context({\"step\": submitted.step})\n\n        metrics = submitted.metrics\n        metrics[\"progress\"] = min((submitted.step + 1) / progress_denominator, 1.0)\n\n        if config.save_every > 0 and submitted.step % config.save_every == 0 and submitted.step > 0:\n            async with trace.scope_span(\"save_checkpoint\"):\n                # Enqueue a checkpoint save after the forward/backward and optimizer\n                # requests for this step; the snapshot will reflect post-step weights.\n                await checkpoint_utils.save_checkpoint_async(\n                    training_client=training_client,\n                    name=f\"{submitted.step:06d}\",\n                    
log_path=config.log_path,\n                    loop_state={\"epoch\": submitted.epoch_idx, \"batch\": submitted.batch_idx},\n                    kind=\"both\",\n                    ttl_seconds=config.ttl_seconds,\n                )\n\n        async with trace.scope_span(\"step\"):\n            fwd_bwd_result = await submitted.fwd_bwd_future.result_async()\n            optim_step_result = await submitted.optim_step_future.result_async()\n\n        if optim_step_result.metrics:\n            metrics.update(optim_step_result.metrics)\n\n        logprobs = [x[\"logprobs\"] for x in fwd_bwd_result.loss_fn_outputs]\n        weights = [datum.loss_fn_inputs[\"weights\"] for datum in submitted.data]\n        train_nll = compute_mean_nll(logprobs, weights)\n\n        metrics.update(\n            num_sequences=len(submitted.data),\n            num_tokens=sum(datum.model_input.length for datum in submitted.data),\n            num_loss_tokens=sum(\n                sum(datum.loss_fn_inputs[\"weights\"].data) for datum in submitted.data\n            ),\n            train_mean_nll=train_nll,\n        )\n        # Merge evaluation metrics gathered before the training step was submitted\n        if submitted.eval_metrics is not None:\n            metrics.update(submitted.eval_metrics)\n\n        if submitted.infrequent_eval_metrics is not None:\n            metrics.update(submitted.infrequent_eval_metrics)\n\n    pending_batch: SubmittedBatch | None = None\n    log_path = Path(config.log_path)\n\n    async def finish_and_log(submitted: SubmittedBatch, window: trace.IterationWindow) -> None:\n        \"\"\"Finish a batch, merge timing metrics, and log.\"\"\"\n        await finish_batch(submitted)\n        submitted.metrics.update(window.get_timing_metrics())\n        window.write_spans_jsonl(log_path / \"timing_spans.jsonl\", step=submitted.step)\n        if config.span_chart_every > 0 and submitted.step % config.span_chart_every == 0:\n            trace.save_gantt_chart_html(\n                window, submitted.step, log_path / f\"timing_gantt_{submitted.step:06d}.html\"\n            )\n        ml_logger.log_metrics(metrics=submitted.metrics, step=submitted.step)\n\n    reached_max_steps = False\n    for epoch_idx in range(start_epoch, config.num_epochs):\n        logger.info(f\"Starting epoch {epoch_idx}\")\n        dataset.set_epoch(seed=epoch_idx)\n\n        start_batch_idx = start_batch if epoch_idx == start_epoch else 0\n        for batch_idx in range(start_batch_idx, n_batches):\n            step = epoch_idx * n_batches + batch_idx\n            if config.max_steps is not None and step >= config.max_steps:\n                reached_max_steps = True\n                break\n            with trace.trace_iteration(step=step) as window:\n                submitted_batch = await submit_batch(epoch_idx, batch_idx)\n                if pending_batch is not None:\n                    await finish_and_log(pending_batch, window)\n            pending_batch = submitted_batch\n        if reached_max_steps:\n            break\n\n    if pending_batch is not None:\n        with trace.trace_iteration(step=pending_batch.step) as window:\n            await finish_and_log(pending_batch, window)\n\n    did_train = start_epoch < config.num_epochs and (\n        config.max_steps is None or start_epoch * n_batches + start_batch < config.max_steps\n    )\n    if did_train:\n        await checkpoint_utils.save_checkpoint_async(\n            training_client=training_client,\n            name=\"final\",\n            
log_path=config.log_path,\n            kind=\"both\",\n            loop_state={\"epoch\": config.num_epochs, \"batch\": 0},\n            ttl_seconds=None,\n        )\n    else:\n        logger.info(\"Training was already complete; nothing to do\")\n\n    ml_logger.close()\n    logger.info(\"Training completed successfully\")\n\n\nif __name__ == \"__main__\":\n    chz.nested_entrypoint(lambda config: asyncio.run(main(config)), allow_hyphens=True)\n"
  },
  {
    "path": "tinker_cookbook/supervised/types.py",
    "content": "\"\"\"\nBasic interfaces and types for supervised training.\n\"\"\"\n\nimport logging\n\nimport chz\nimport tinker\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.tokenizer_utils import Tokenizer, get_tokenizer\n\nlogger = logging.getLogger(__name__)\n\n\nclass SupervisedDataset:\n    \"\"\"\n    Dataset used for supervised learning\n    \"\"\"\n\n    def get_batch(self, index: int) -> list[tinker.Datum]:\n        raise NotImplementedError\n\n    def __len__(self) -> int:\n        raise NotImplementedError\n\n    def set_epoch(self, seed: int = 0):\n        \"\"\"Tell the dataset that we're on the given epoch of training.\n        Datasets can decide what to do with this information, but for best\n        results with multi-epoch training, you might want to shuffle differently each epoch,\n        though results on whether this helps are inconclusive.\n        \"\"\"\n        logger.warning(\n            \"set_epoch called, but shuffling is not implemented for %s\",\n            self.__class__.__name__,\n        )\n\n\n@chz.chz\nclass SupervisedDatasetBuilder:\n    \"\"\"\n    A config class that knows how to construct a supervised dataset. This dataset is usually a chat dataset but doesn't need to be; it could just be tokens.\n    \"\"\"\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        raise NotImplementedError\n\n\n@chz.chz\nclass ChatDatasetBuilderCommonConfig:\n    \"\"\"\n    Config that all chat dataset builders have\n    Some specific datasets have additional options.\n    \"\"\"\n\n    model_name_for_tokenizer: str\n\n    renderer_name: str\n    max_length: int | None\n    batch_size: int\n    train_on_what: renderers.TrainOnWhat | None = None\n\n\n@chz.chz\nclass ChatDatasetBuilder(SupervisedDatasetBuilder):\n    \"\"\"\n    Builds a chat dataset, which is a dataset that uses a renderer to convert from\n    list-of-messages to tokens.\n    \"\"\"\n\n    common_config: ChatDatasetBuilderCommonConfig\n\n    def __call__(self) -> tuple[SupervisedDataset, SupervisedDataset | None]:\n        \"\"\"\n        Return a training dataset and optionally an evaluation dataset.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def tokenizer(self) -> Tokenizer:\n        return get_tokenizer(self.common_config.model_name_for_tokenizer)\n\n    @property\n    def renderer(self) -> renderers.Renderer:\n        return renderers.get_renderer(\n            self.common_config.renderer_name,\n            self.tokenizer,\n        )\n"
  },
  {
    "path": "tinker_cookbook/supervised/viz_sft_dataset.py",
    "content": "\"\"\"\nScript to visualize supervised datasets in the terminal.\n\"\"\"\n\nimport chz\n\nfrom tinker_cookbook import model_info\nfrom tinker_cookbook.renderers import TrainOnWhat\nfrom tinker_cookbook.supervised.types import (\n    ChatDatasetBuilderCommonConfig,\n    SupervisedDatasetBuilder,\n)\nfrom tinker_cookbook.tokenizer_utils import get_tokenizer\nfrom tinker_cookbook.utils.format_colorized import format_colorized\nfrom tinker_cookbook.utils.misc_utils import lookup_func\n\n\n@chz.chz\nclass Config:\n    model_name: str = \"meta-llama/Llama-3.1-8B\"  # just for tokenizer\n    dataset_path: str = \"Tulu3Builder\"\n    renderer_name: str | None = None\n    max_length: int | None = None\n    train_on_what: TrainOnWhat | None = None\n\n\ndef run(cfg: Config):\n    n_examples_total = 100\n    common_config = ChatDatasetBuilderCommonConfig(\n        model_name_for_tokenizer=cfg.model_name,\n        renderer_name=cfg.renderer_name or model_info.get_recommended_renderer_name(cfg.model_name),\n        max_length=cfg.max_length,\n        batch_size=n_examples_total,\n        train_on_what=cfg.train_on_what,\n    )\n    dataset_builder = lookup_func(\n        cfg.dataset_path, default_module=\"tinker_cookbook.recipes.chat_sl.chat_datasets\"\n    )(common_config=common_config)\n    assert isinstance(dataset_builder, SupervisedDatasetBuilder)\n    tokenizer = get_tokenizer(cfg.model_name)\n    train_dataset, _ = dataset_builder()\n    batch = train_dataset.get_batch(0)\n\n    for datum in batch:\n        int_tokens = list(datum.model_input.to_ints()) + [\n            datum.loss_fn_inputs[\"target_tokens\"].tolist()[-1]\n        ]\n        weights = [0.0] + datum.loss_fn_inputs[\"weights\"].tolist()\n        print(format_colorized(int_tokens, weights, tokenizer))\n        input(\"press enter\")\n\n\nif __name__ == \"__main__\":\n    chz.nested_entrypoint(run)\n"
  },
  {
    "path": "tinker_cookbook/third_party/__init__.py",
    "content": ""
  },
  {
    "path": "tinker_cookbook/third_party/litellm/README.md",
    "content": "# LiteLLM Integration\n\nA [LiteLLM](https://docs.litellm.ai/) custom provider that routes calls through Tinker's native `SamplingClient` for optimal sampling performance.\n\n## Why use this?\n\nIf you have an agent or application built on LiteLLM (or frameworks that use it, like LangChain, CrewAI, or AutoGen), this integration lets you:\n\n1. **Run your existing code against Tinker** without rewriting it to use the Tinker SDK directly\n2. **Get raw token IDs** from every request, which you can feed into Tinker's training APIs for supervised learning or RL\n\nTinker also offers an [OpenAI-compatible endpoint](/compatible-apis/openai), which works with LiteLLM out of the box. However, the native `SamplingClient` used by this integration provides better performance.\n\n## Setup\n\nThis integration requires `litellm` as an additional dependency. From the tinker-cookbook repo root:\n\n```bash\nuv pip install -e \".[litellm]\"\n```\n\nYou also need a `TINKER_API_KEY` — see [Getting an API key](https://tinker-docs.thinkingmachines.ai/install#getting-an-api-key).\n\n## Quick start\n\n```python\nfrom tinker_cookbook.third_party.litellm import register_litellm_provider\nimport litellm\n\n# Register once at startup\nregister_litellm_provider()\n\n# The \"tinker/\" prefix routes to this provider.\n# base_model is the Tinker model to sample from.\nresponse = await litellm.acompletion(\n    model=\"tinker/my-label\",\n    messages=[{\"role\": \"user\", \"content\": \"Hello!\"}],\n    base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n    temperature=0.7,\n    max_tokens=256,\n)\n\nprint(response.choices[0].message.content)\n```\n\n## How `model` and `base_model` work\n\nLiteLLM uses the `model` parameter to decide which provider handles the request. The `tinker/` prefix routes to this provider — everything after the prefix is an arbitrary label that appears in the response metadata (it does **not** select the model).\n\nThe actual model is determined by `base_model`, which is passed directly to Tinker's `ServiceClient.create_sampling_client(base_model=...)`. This must be a model name from Tinker's [model lineup](https://tinker-docs.thinkingmachines.ai/model-lineup), e.g.:\n\n- `Qwen/Qwen3-4B-Instruct-2507`\n- `meta-llama/Llama-3.1-8B-Instruct`\n- `moonshotai/Kimi-K2.5`\n\nYou can list available models with:\n\n```python\nimport tinker\nservice = tinker.ServiceClient()\nfor m in service.get_server_capabilities().supported_models:\n    print(m.model_name)\n```\n\n### Sampling from a fine-tuned checkpoint\n\nWhen you pass `base_model`, the provider creates a `SamplingClient` for that pretrained model. 
To sample from a **fine-tuned** checkpoint instead, use `set_client()` to inject a `SamplingClient` created with a `model_path`:\n\n```python\nimport tinker\n\nprovider = register_litellm_provider()\n\n# Create a sampling client pointing at your fine-tuned checkpoint.\n# The model_path comes from training_client.save_weights_for_sampler().\nservice = tinker.ServiceClient()\nsampler = service.create_sampling_client(\n    model_path=\"tinker://<experiment-id>/sampler_weights/000080\"\n)\n\n# The provider reads the base model from the sampling client automatically\n# to resolve the correct renderer and tokenizer.\nprovider.set_client(sampler)\n\n# Now litellm calls will sample from your fine-tuned checkpoint.\n# base_model must still match so the provider finds the right client bundle.\nresponse = await litellm.acompletion(\n    model=\"tinker/my-finetuned\",\n    messages=[{\"role\": \"user\", \"content\": \"Hello!\"}],\n    base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n)\n```\n\nSee [Saving and loading weights](https://tinker-docs.thinkingmachines.ai/save-load) for how to obtain checkpoint paths.\n\n### Custom Tinker deployments\n\nFor private or non-default Tinker deployments, pass a pre-configured `ServiceClient`:\n\n```python\nimport tinker\n\nservice = tinker.ServiceClient(base_url=\"https://my-tinker.example.com\")\nprovider = register_litellm_provider(service_client=service)\n```\n\n## Accessing raw tokens for training\n\nThe key feature of this integration is token-level access for training workflows:\n\n```python\nresponse = await litellm.acompletion(\n    model=\"tinker/my-label\",\n    messages=messages,\n    base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n)\n\n# Raw token IDs are in provider_specific_fields\nfields = response.choices[0].message.provider_specific_fields\nprompt_token_ids = fields[\"prompt_token_ids\"]       # list[int]\ncompletion_token_ids = fields[\"completion_token_ids\"]  # list[int]\n\n# Use these directly with Tinker's training APIs\n```\n\n## Supported parameters\n\n| LiteLLM parameter | Description |\n|---|---|\n| `model` | Must start with `tinker/` to route to this provider. The rest is a label for the response metadata. |\n| `base_model` | **Required.** Tinker model name passed to `create_sampling_client()`. See [model lineup](https://tinker-docs.thinkingmachines.ai/model-lineup). 
|\n| `temperature` | Sampling temperature |\n| `max_tokens` / `max_completion_tokens` | Maximum tokens to generate |\n| `top_p` | Nucleus sampling parameter |\n| `top_k` | Top-k sampling parameter |\n| `stop` | Stop sequences (defaults to model's stop sequences) |\n| `tools` | OpenAI-format tool definitions |\n\n## Tool calling\n\nTool declarations are supported for models whose renderers implement `create_conversation_prefix_with_tools` (Qwen3, DeepSeek V3, Kimi K2/K2.5, GPT-OSS):\n\n```python\nresponse = await litellm.acompletion(\n    model=\"tinker/my-agent\",\n    messages=[{\"role\": \"user\", \"content\": \"What's the weather in SF?\"}],\n    base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n    tools=[{\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_weather\",\n            \"description\": \"Get weather for a city\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\"city\": {\"type\": \"string\"}},\n                \"required\": [\"city\"],\n            },\n        },\n    }],\n)\n```\n\n## Sync and async\n\nBoth `litellm.completion()` and `litellm.acompletion()` are supported.\n\n## API reference\n\n### `register_litellm_provider(*, service_client=None)`\n\nRegister the Tinker provider with LiteLLM. Returns a provider instance.\n\n- **Idempotent** — safe to call multiple times; returns the same instance after the first call.\n- `service_client` (`tinker.ServiceClient | None`) — optional pre-configured client for custom deployments. If `None`, a default `ServiceClient` is created on first use. Ignored on subsequent calls.\n\n### `provider.set_client(sampling_client)`\n\nInject a custom `SamplingClient` into the provider (e.g., for a fine-tuned checkpoint).\n\n- `sampling_client` (`tinker.SamplingClient`) — the client to use. The base model is read automatically via `sampling_client.get_base_model()` to resolve the correct renderer and tokenizer.\n- If a client bundle for that base model already exists, only the sampling client is replaced (renderer and tokenizer are reused).\n"
  },
  {
    "path": "tinker_cookbook/third_party/litellm/__init__.py",
    "content": "from tinker_cookbook.third_party.litellm.provider import register_litellm_provider\n\n__all__ = [\n    \"register_litellm_provider\",\n]\n"
  },
  {
    "path": "tinker_cookbook/third_party/litellm/provider.py",
    "content": "\"\"\"\nLiteLLM custom provider for Tinker sampling.\n\nEnables using Tinker's native SamplingClient through LiteLLM's unified interface,\ngiving optimal sampling performance while exposing raw token IDs for training.\n\nUsage::\n\n    from tinker_cookbook.third_party.litellm import register_litellm_provider\n    import litellm\n\n    register_litellm_provider()\n\n    response = await litellm.acompletion(\n        model=\"tinker/my-model\",\n        messages=[{\"role\": \"user\", \"content\": \"Hello!\"}],\n        base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n    )\n\n    # Access raw tokens for training\n    fields = response.choices[0].message.provider_specific_fields\n    prompt_tokens = fields[\"prompt_token_ids\"]\n    completion_tokens = fields[\"completion_token_ids\"]\n\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport time\nimport uuid\nfrom collections.abc import Callable\nfrom dataclasses import dataclass\nfrom typing import Any, Union\n\nimport httpx\nimport tinker\n\nfrom tinker_cookbook import renderers\nfrom tinker_cookbook.model_info import get_recommended_renderer_name\nfrom tinker_cookbook.renderers.base import ToolCall\nfrom tinker_cookbook.third_party.openai_compat import (\n    openai_messages_to_tinker,\n    openai_tools_to_tinker,\n)\nfrom tinker_cookbook.tokenizer_utils import Tokenizer, get_tokenizer\n\ntry:\n    from litellm.llms.custom_llm import CustomLLM\n    from litellm.types.utils import Choices, Message, ModelResponse, Usage\nexcept ImportError:\n    raise ImportError(\n        \"litellm is required for the Tinker LiteLLM integration. \"\n        \"Install it with: uv pip install -e '.[litellm]'\"\n    ) from None\n\n\n# ---------------------------------------------------------------------------\n# Internal helpers: sampling pipeline and response building\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass _SamplingResult:\n    \"\"\"Result of a Tinker sampling call with all data needed to build any response format.\"\"\"\n\n    prompt_token_ids: list[int]\n    completion_token_ids: list[int]\n    logprobs: list[float] | None\n    parsed_message: renderers.Message\n    parse_success: bool\n    model_name: str\n\n\ndef _prepare_messages_with_tools(\n    renderer: renderers.Renderer,\n    messages: list[renderers.Message],\n    tools: list[dict[str, Any]],\n) -> list[renderers.Message]:\n    \"\"\"Inject tool declarations into the message list via the renderer.\n\n    Extracts the system message (if any), passes it to\n    ``renderer.create_conversation_prefix_with_tools``, and prepends the\n    resulting prefix messages to the remaining conversation.\n    \"\"\"\n    tool_specs = openai_tools_to_tinker(tools)\n\n    # Split out system message if present\n    system_prompt = \"\"\n    remaining: list[renderers.Message]\n    if messages and messages[0][\"role\"] == \"system\":\n        content = messages[0].get(\"content\") or \"\"\n        system_prompt = content if isinstance(content, str) else \"\"\n        remaining = list(messages[1:])\n    else:\n        remaining = list(messages)\n\n    prefix = renderer.create_conversation_prefix_with_tools(tool_specs, system_prompt)\n    return prefix + remaining\n\n\nasync def _sample_chat_completion(\n    sampling_client: tinker.SamplingClient,\n    renderer: renderers.Renderer,\n    messages: list[dict[str, Any]],\n    *,\n    temperature: float = 1.0,\n    max_tokens: int = 128,\n    top_p: float = 1.0,\n    top_k: int = -1,\n   
 stop: list[str] | list[int] | None = None,\n    tools: list[dict[str, Any]] | None = None,\n    model_name: str = \"tinker\",\n) -> _SamplingResult:\n    \"\"\"Run the full render -> sample -> parse pipeline.\"\"\"\n    tinker_messages = openai_messages_to_tinker(messages)\n\n    if tools:\n        tinker_messages = _prepare_messages_with_tools(renderer, tinker_messages, tools)\n\n    model_input = renderer.build_generation_prompt(tinker_messages)\n    prompt_token_ids: list[int] = model_input.to_ints()\n\n    if stop is None:\n        stop = renderer.get_stop_sequences()\n\n    sample_response = await sampling_client.sample_async(\n        prompt=model_input,\n        num_samples=1,\n        sampling_params=tinker.SamplingParams(\n            temperature=temperature,\n            max_tokens=max_tokens,\n            top_p=top_p,\n            top_k=top_k,\n            stop=stop,\n        ),\n    )\n\n    seq = sample_response.sequences[0]\n    completion_token_ids: list[int] = seq.tokens\n    logprobs: list[float] | None = seq.logprobs\n\n    parsed_message, parse_success = renderer.parse_response(completion_token_ids)\n\n    return _SamplingResult(\n        prompt_token_ids=prompt_token_ids,\n        completion_token_ids=completion_token_ids,\n        logprobs=logprobs,\n        parsed_message=parsed_message,\n        parse_success=parse_success,\n        model_name=model_name,\n    )\n\n\ndef _sampling_result_to_chat_completion_dict(result: _SamplingResult) -> dict[str, Any]:\n    \"\"\"Convert a _SamplingResult to an OpenAI ChatCompletion-compatible dict.\"\"\"\n    content = result.parsed_message.get(\"content\", \"\")\n    if isinstance(content, list):\n        content = renderers.format_content_as_string(content)\n\n    # Build tool_calls list if present\n    tool_calls_out: list[dict[str, Any]] | None = None\n    raw_tool_calls: list[ToolCall] | None = result.parsed_message.get(\"tool_calls\")\n    if raw_tool_calls:\n        tool_calls_out = [\n            {\n                \"id\": tc.id or f\"call_{i}\",\n                \"type\": \"function\",\n                \"function\": {\"name\": tc.function.name, \"arguments\": tc.function.arguments},\n            }\n            for i, tc in enumerate(raw_tool_calls)\n        ]\n\n    if tool_calls_out:\n        finish_reason = \"tool_calls\"\n    elif result.parse_success:\n        finish_reason = \"stop\"\n    else:\n        finish_reason = \"length\"\n\n    message_dict: dict[str, Any] = {\n        \"role\": \"assistant\",\n        \"content\": content or None,\n    }\n    if tool_calls_out:\n        message_dict[\"tool_calls\"] = tool_calls_out\n\n    return {\n        \"id\": f\"chatcmpl-tinker-{uuid.uuid4().hex[:12]}\",\n        \"object\": \"chat.completion\",\n        \"created\": int(time.time()),\n        \"model\": result.model_name,\n        \"choices\": [\n            {\n                \"index\": 0,\n                \"message\": message_dict,\n                \"finish_reason\": finish_reason,\n            }\n        ],\n        \"usage\": {\n            \"prompt_tokens\": len(result.prompt_token_ids),\n            \"completion_tokens\": len(result.completion_token_ids),\n            \"total_tokens\": len(result.prompt_token_ids) + len(result.completion_token_ids),\n        },\n    }\n\n\ndef _extract_sampling_params(optional_params: dict[str, Any]) -> dict[str, Any]:\n    \"\"\"Extract Tinker-compatible sampling parameters from LiteLLM optional_params.\"\"\"\n    params: dict[str, Any] = {}\n    if \"temperature\" in 
optional_params:\n        params[\"temperature\"] = float(optional_params[\"temperature\"])\n    if \"max_tokens\" in optional_params:\n        params[\"max_tokens\"] = int(optional_params[\"max_tokens\"])\n    elif \"max_completion_tokens\" in optional_params:\n        params[\"max_tokens\"] = int(optional_params[\"max_completion_tokens\"])\n    if \"top_p\" in optional_params:\n        params[\"top_p\"] = float(optional_params[\"top_p\"])\n    if \"top_k\" in optional_params:\n        params[\"top_k\"] = int(optional_params[\"top_k\"])\n    if \"stop\" in optional_params:\n        params[\"stop\"] = optional_params[\"stop\"]\n    return params\n\n\ndef _build_model_response(\n    result: _SamplingResult,\n    model_response: ModelResponse,\n) -> ModelResponse:\n    \"\"\"Populate a LiteLLM ModelResponse from a _SamplingResult.\"\"\"\n    completion_dict = _sampling_result_to_chat_completion_dict(result)\n\n    choice_data = completion_dict[\"choices\"][0]\n    message_data = choice_data[\"message\"]\n\n    model_response.choices = [\n        Choices(\n            finish_reason=choice_data[\"finish_reason\"],\n            index=0,\n            message=Message(\n                content=message_data.get(\"content\"),\n                role=\"assistant\",\n                tool_calls=message_data.get(\"tool_calls\"),\n                provider_specific_fields={\n                    \"prompt_token_ids\": result.prompt_token_ids,\n                    \"completion_token_ids\": result.completion_token_ids,\n                },\n            ),\n        )\n    ]\n\n    usage_data = completion_dict[\"usage\"]\n    model_response.usage = Usage(  # type: ignore[assignment]\n        prompt_tokens=usage_data[\"prompt_tokens\"],\n        completion_tokens=usage_data[\"completion_tokens\"],\n        total_tokens=usage_data[\"total_tokens\"],\n    )\n    model_response.model = result.model_name\n\n    return model_response\n\n\ndef _map_tinker_error(exc: Exception) -> Exception:\n    \"\"\"Map Tinker SDK exceptions to LiteLLM-compatible errors.\"\"\"\n    import litellm.exceptions\n\n    if isinstance(exc, tinker.AuthenticationError):\n        return litellm.exceptions.AuthenticationError(\n            message=str(exc),\n            llm_provider=\"tinker\",\n            model=\"\",\n        )\n    if isinstance(exc, tinker.RateLimitError):\n        return litellm.exceptions.RateLimitError(\n            message=str(exc),\n            llm_provider=\"tinker\",\n            model=\"\",\n        )\n    if isinstance(exc, tinker.APITimeoutError):\n        return litellm.exceptions.Timeout(\n            message=str(exc),\n            llm_provider=\"tinker\",\n            model=\"\",\n        )\n    if isinstance(exc, tinker.APIConnectionError):\n        return litellm.exceptions.APIConnectionError(\n            message=str(exc),\n            llm_provider=\"tinker\",\n            model=\"\",\n        )\n    if isinstance(exc, tinker.BadRequestError):\n        return litellm.exceptions.BadRequestError(\n            message=str(exc),\n            llm_provider=\"tinker\",\n            model=\"\",\n        )\n    # Fallback: re-raise as-is\n    return exc\n\n\n# ---------------------------------------------------------------------------\n# Client bundle and provider\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass _ClientBundle:\n    \"\"\"Cached group of objects needed to sample from a specific model.\"\"\"\n\n    sampling_client: tinker.SamplingClient\n    renderer: 
renderers.Renderer\n    tokenizer: Tokenizer\n    base_model: str\n\n\nclass TinkerLiteLLMProvider(CustomLLM):\n    \"\"\"LiteLLM custom provider that routes calls through Tinker's native SamplingClient.\"\"\"\n\n    def __init__(\n        self,\n        service_client: tinker.ServiceClient | None = None,\n    ) -> None:\n        super().__init__()\n        self._clients: dict[str, _ClientBundle] = {}\n        self._service_client = service_client\n\n    def _get_service_client(self) -> tinker.ServiceClient:\n        if self._service_client is None:\n            self._service_client = tinker.ServiceClient()\n        return self._service_client\n\n    def _get_or_create_client(self, base_model: str) -> _ClientBundle:\n        \"\"\"Get or lazily create a client bundle for the given base model.\"\"\"\n        if base_model not in self._clients:\n            tokenizer = get_tokenizer(base_model)\n            renderer_name = get_recommended_renderer_name(base_model)\n            renderer = renderers.get_renderer(renderer_name, tokenizer)\n            sampling_client = self._get_service_client().create_sampling_client(\n                base_model=base_model\n            )\n            self._clients[base_model] = _ClientBundle(\n                sampling_client=sampling_client,\n                renderer=renderer,\n                tokenizer=tokenizer,\n                base_model=base_model,\n            )\n        return self._clients[base_model]\n\n    def set_client(\n        self,\n        sampling_client: tinker.SamplingClient,\n    ) -> None:\n        \"\"\"Inject a custom SamplingClient (e.g., for a fine-tuned checkpoint).\n\n        The base model is read from the client via ``get_base_model()``,\n        and used to resolve the correct renderer and tokenizer. If a bundle\n        for that base model already exists, only the sampling client is replaced.\n        \"\"\"\n        base_model = sampling_client.get_base_model()\n        if base_model in self._clients:\n            self._clients[base_model].sampling_client = sampling_client\n        else:\n            tokenizer = get_tokenizer(base_model)\n            renderer_name = get_recommended_renderer_name(base_model)\n            renderer = renderers.get_renderer(renderer_name, tokenizer)\n            self._clients[base_model] = _ClientBundle(\n                sampling_client=sampling_client,\n                renderer=renderer,\n                tokenizer=tokenizer,\n                base_model=base_model,\n            )\n\n    async def acompletion(\n        self,\n        model: str,\n        messages: list,\n        api_base: str,\n        custom_prompt_dict: dict,\n        model_response: ModelResponse,\n        print_verbose: Callable,\n        encoding,\n        api_key,\n        logging_obj,\n        optional_params: dict,\n        acompletion=None,\n        litellm_params=None,\n        logger_fn=None,\n        headers={},  # noqa: B006\n        timeout: Union[float, httpx.Timeout] | None = None,\n        client=None,\n    ) -> ModelResponse:\n        base_model: str = (litellm_params or {}).get(\"base_model\", \"\")\n        if not base_model:\n            raise ValueError(\n                \"base_model is required for the Tinker provider. 
\"\n                \"Pass it as: litellm.acompletion(..., base_model='Qwen/Qwen3-4B-Instruct-2507')\"\n            )\n\n        bundle = self._get_or_create_client(base_model)\n        sampling_params = _extract_sampling_params(optional_params)\n\n        try:\n            result = await _sample_chat_completion(\n                sampling_client=bundle.sampling_client,\n                renderer=bundle.renderer,\n                messages=messages,\n                tools=optional_params.get(\"tools\"),\n                model_name=model,\n                **sampling_params,\n            )\n        except tinker.TinkerError as exc:\n            raise _map_tinker_error(exc) from exc\n\n        return _build_model_response(result, model_response)\n\n    def completion(\n        self,\n        model: str,\n        messages: list,\n        api_base: str,\n        custom_prompt_dict: dict,\n        model_response: ModelResponse,\n        print_verbose: Callable,\n        encoding,\n        api_key,\n        logging_obj,\n        optional_params: dict,\n        acompletion=None,\n        litellm_params=None,\n        logger_fn=None,\n        headers={},  # noqa: B006\n        timeout: Union[float, httpx.Timeout] | None = None,\n        client=None,\n    ) -> ModelResponse:\n        base_model: str = (litellm_params or {}).get(\"base_model\", \"\")\n        if not base_model:\n            raise ValueError(\n                \"base_model is required for the Tinker provider. \"\n                \"Pass it as: litellm.completion(..., base_model='Qwen/Qwen3-4B-Instruct-2507')\"\n            )\n\n        bundle = self._get_or_create_client(base_model)\n        sampling_params = _extract_sampling_params(optional_params)\n\n        coro = _sample_chat_completion(\n            sampling_client=bundle.sampling_client,\n            renderer=bundle.renderer,\n            messages=messages,\n            tools=optional_params.get(\"tools\"),\n            model_name=model,\n            **sampling_params,\n        )\n\n        try:\n            # If there's already a running event loop (e.g., Jupyter), use it.\n            # Otherwise, create a new one.\n            try:\n                loop = asyncio.get_running_loop()\n            except RuntimeError:\n                loop = None\n\n            if loop is not None and loop.is_running():\n                import concurrent.futures\n\n                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:\n                    result = pool.submit(asyncio.run, coro).result()\n            else:\n                result = asyncio.run(coro)\n        except tinker.TinkerError as exc:\n            raise _map_tinker_error(exc) from exc\n\n        return _build_model_response(result, model_response)\n\n\n_registered_provider: TinkerLiteLLMProvider | None = None\n\n\ndef register_litellm_provider(\n    *,\n    service_client: tinker.ServiceClient | None = None,\n) -> TinkerLiteLLMProvider:\n    \"\"\"Register the Tinker provider with LiteLLM.\n\n    Safe to call multiple times — returns the same provider instance after\n    the first call. 
Use the returned instance to inject custom SamplingClients\n    via ``provider.set_client(sampling_client)``.\n\n    Args:\n        service_client: Optional pre-configured ``tinker.ServiceClient``.\n            Useful for custom deployments with a non-default ``base_url``.\n            If None, a default ``ServiceClient`` is created on first use.\n            Ignored on subsequent calls (singleton already exists).\n    \"\"\"\n    import litellm\n\n    global _registered_provider\n    if _registered_provider is not None:\n        return _registered_provider\n\n    provider = TinkerLiteLLMProvider(service_client=service_client)\n    litellm.custom_provider_map.append({\"provider\": \"tinker\", \"custom_handler\": provider})\n    _registered_provider = provider\n    return provider\n"
  },
  {
    "path": "tinker_cookbook/third_party/litellm/provider_test.py",
    "content": "\"\"\"Tests for the LiteLLM integration.\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Any\nfrom unittest.mock import AsyncMock, MagicMock, patch\n\nimport pytest\n\nfrom tinker_cookbook.renderers.base import ToolCall\nfrom tinker_cookbook.third_party.litellm.provider import (\n    _extract_sampling_params,\n    _prepare_messages_with_tools,\n    _sample_chat_completion,\n    _sampling_result_to_chat_completion_dict,\n    _SamplingResult,\n)\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass FakeSampledSequence:\n    tokens: list[int]\n    logprobs: list[float] | None\n    stop_reason: str = \"stop\"\n\n\n@dataclass\nclass FakeSampleResponse:\n    sequences: list[FakeSampledSequence]\n\n\ndef _make_sampling_result(\n    *,\n    prompt_tokens: list[int] | None = None,\n    completion_tokens: list[int] | None = None,\n    content: str = \"Hello!\",\n    parse_success: bool = True,\n    tool_calls: list[ToolCall] | None = None,\n) -> _SamplingResult:\n    msg: dict[str, Any] = {\"role\": \"assistant\", \"content\": content}\n    if tool_calls is not None:\n        msg[\"tool_calls\"] = tool_calls\n    return _SamplingResult(\n        prompt_token_ids=prompt_tokens or [1, 2, 3],\n        completion_token_ids=completion_tokens or [4, 5, 6],\n        logprobs=[0.1, 0.2, 0.3],\n        parsed_message=msg,  # type: ignore[arg-type]\n        parse_success=parse_success,\n        model_name=\"tinker/test-model\",\n    )\n\n\n# ---------------------------------------------------------------------------\n# _extract_sampling_params\n# ---------------------------------------------------------------------------\n\n\nclass TestExtractSamplingParams:\n    def test_all_params(self) -> None:\n        params = _extract_sampling_params(\n            {\n                \"temperature\": 0.5,\n                \"max_tokens\": 256,\n                \"top_p\": 0.9,\n                \"top_k\": 50,\n                \"stop\": [\"STOP\"],\n                \"irrelevant_param\": True,\n            }\n        )\n        assert params == {\n            \"temperature\": 0.5,\n            \"max_tokens\": 256,\n            \"top_p\": 0.9,\n            \"top_k\": 50,\n            \"stop\": [\"STOP\"],\n        }\n\n    def test_max_completion_tokens(self) -> None:\n        params = _extract_sampling_params({\"max_completion_tokens\": 128})\n        assert params == {\"max_tokens\": 128}\n\n    def test_empty(self) -> None:\n        assert _extract_sampling_params({}) == {}\n\n\n# ---------------------------------------------------------------------------\n# _prepare_messages_with_tools\n# ---------------------------------------------------------------------------\n\n\nclass TestPrepareMessagesWithTools:\n    def test_extracts_system_message(self) -> None:\n        renderer = MagicMock()\n        renderer.create_conversation_prefix_with_tools.return_value = [\n            {\"role\": \"system\", \"content\": \"You have tools: [search]. 
Also: Be helpful.\"}\n        ]\n\n        messages = [\n            {\"role\": \"system\", \"content\": \"Be helpful.\"},\n            {\"role\": \"user\", \"content\": \"Hi\"},\n        ]\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\"name\": \"search\", \"description\": \"Search\", \"parameters\": {}},\n            }\n        ]\n\n        result = _prepare_messages_with_tools(renderer, messages, tools)  # type: ignore[arg-type]\n\n        renderer.create_conversation_prefix_with_tools.assert_called_once()\n        args = renderer.create_conversation_prefix_with_tools.call_args\n        assert args[0][1] == \"Be helpful.\"  # system_prompt extracted\n        # User message comes after the prefix\n        assert result[-1][\"role\"] == \"user\"\n\n    def test_no_system_message(self) -> None:\n        renderer = MagicMock()\n        renderer.create_conversation_prefix_with_tools.return_value = [\n            {\"role\": \"system\", \"content\": \"Tools: [search]\"}\n        ]\n\n        messages = [{\"role\": \"user\", \"content\": \"Hi\"}]\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\"name\": \"search\", \"description\": \"Search\", \"parameters\": {}},\n            }\n        ]\n\n        _prepare_messages_with_tools(renderer, messages, tools)  # type: ignore[arg-type]\n\n        args = renderer.create_conversation_prefix_with_tools.call_args\n        assert args[0][1] == \"\"  # no system prompt\n\n\n# ---------------------------------------------------------------------------\n# _sampling_result_to_chat_completion_dict\n# ---------------------------------------------------------------------------\n\n\nclass TestSamplingResultToDict:\n    def test_basic_response(self) -> None:\n        result = _make_sampling_result(content=\"Hi there!\")\n        d = _sampling_result_to_chat_completion_dict(result)\n\n        assert d[\"object\"] == \"chat.completion\"\n        assert d[\"model\"] == \"tinker/test-model\"\n        assert len(d[\"choices\"]) == 1\n        assert d[\"choices\"][0][\"message\"][\"content\"] == \"Hi there!\"\n        assert d[\"choices\"][0][\"message\"][\"role\"] == \"assistant\"\n        assert d[\"choices\"][0][\"finish_reason\"] == \"stop\"\n        assert d[\"usage\"][\"prompt_tokens\"] == 3\n        assert d[\"usage\"][\"completion_tokens\"] == 3\n\n    def test_parse_failure_gives_length_finish(self) -> None:\n        result = _make_sampling_result(parse_success=False)\n        d = _sampling_result_to_chat_completion_dict(result)\n        assert d[\"choices\"][0][\"finish_reason\"] == \"length\"\n\n    def test_tool_calls_in_response(self) -> None:\n        tc = ToolCall(\n            function=ToolCall.FunctionBody(name=\"search\", arguments='{\"q\": \"test\"}'),\n            id=\"call_abc\",\n        )\n        result = _make_sampling_result(tool_calls=[tc])\n        d = _sampling_result_to_chat_completion_dict(result)\n\n        assert d[\"choices\"][0][\"finish_reason\"] == \"tool_calls\"\n        tool_calls = d[\"choices\"][0][\"message\"][\"tool_calls\"]\n        assert len(tool_calls) == 1\n        assert tool_calls[0][\"function\"][\"name\"] == \"search\"\n        assert tool_calls[0][\"id\"] == \"call_abc\"\n\n    def test_tool_call_without_id_gets_generated(self) -> None:\n        tc = ToolCall(\n            function=ToolCall.FunctionBody(name=\"search\", arguments=\"{}\"),\n            id=None,\n        )\n        result = 
_make_sampling_result(tool_calls=[tc])\n        d = _sampling_result_to_chat_completion_dict(result)\n        assert d[\"choices\"][0][\"message\"][\"tool_calls\"][0][\"id\"] == \"call_0\"\n\n    def test_list_content_formatted_as_string(self) -> None:\n        result = _make_sampling_result()\n        result.parsed_message[\"content\"] = [\n            {\"type\": \"text\", \"text\": \"Hello \"},\n            {\"type\": \"text\", \"text\": \"world!\"},\n        ]\n        d = _sampling_result_to_chat_completion_dict(result)\n        assert d[\"choices\"][0][\"message\"][\"content\"] == \"Hello \\nworld!\"\n\n\n# ---------------------------------------------------------------------------\n# _sample_chat_completion\n# ---------------------------------------------------------------------------\n\n\nclass TestSampleChatCompletion:\n    @pytest.mark.asyncio\n    async def test_basic_flow(self) -> None:\n        fake_response = FakeSampleResponse(\n            sequences=[FakeSampledSequence(tokens=[10, 20, 30], logprobs=[0.1, 0.2, 0.3])]\n        )\n        sampling_client = MagicMock()\n        sampling_client.sample_async = AsyncMock(return_value=fake_response)\n\n        renderer = MagicMock()\n        renderer.build_generation_prompt.return_value = MagicMock()\n        renderer.build_generation_prompt.return_value.to_ints.return_value = [1, 2, 3]\n        renderer.get_stop_sequences.return_value = [\"<|endoftext|>\"]\n        renderer.parse_response.return_value = (\n            {\"role\": \"assistant\", \"content\": \"response\"},\n            True,\n        )\n\n        result = await _sample_chat_completion(\n            sampling_client=sampling_client,\n            renderer=renderer,\n            messages=[{\"role\": \"user\", \"content\": \"Hi\"}],\n            temperature=0.5,\n            max_tokens=64,\n        )\n\n        assert result.prompt_token_ids == [1, 2, 3]\n        assert result.completion_token_ids == [10, 20, 30]\n        assert result.parse_success is True\n        assert result.parsed_message[\"content\"] == \"response\"\n\n        # Verify sampling params were passed correctly\n        call_kwargs = sampling_client.sample_async.call_args.kwargs\n        assert call_kwargs[\"sampling_params\"].temperature == 0.5\n        assert call_kwargs[\"sampling_params\"].max_tokens == 64\n\n    @pytest.mark.asyncio\n    async def test_with_tools(self) -> None:\n        fake_response = FakeSampleResponse(\n            sequences=[FakeSampledSequence(tokens=[10], logprobs=[0.1])]\n        )\n        sampling_client = MagicMock()\n        sampling_client.sample_async = AsyncMock(return_value=fake_response)\n\n        renderer = MagicMock()\n        renderer.build_generation_prompt.return_value = MagicMock()\n        renderer.build_generation_prompt.return_value.to_ints.return_value = [1]\n        renderer.get_stop_sequences.return_value = []\n        renderer.create_conversation_prefix_with_tools.return_value = [\n            {\"role\": \"system\", \"content\": \"Tools: [search]\"}\n        ]\n        renderer.parse_response.return_value = (\n            {\"role\": \"assistant\", \"content\": \"done\"},\n            True,\n        )\n\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\"name\": \"search\", \"description\": \"Search\", \"parameters\": {}},\n            }\n        ]\n        result = await _sample_chat_completion(\n            sampling_client=sampling_client,\n            renderer=renderer,\n            
messages=[{\"role\": \"user\", \"content\": \"Hi\"}],\n            tools=tools,\n        )\n\n        renderer.create_conversation_prefix_with_tools.assert_called_once()\n        assert result.parse_success is True\n\n    @pytest.mark.asyncio\n    async def test_custom_stop_sequences(self) -> None:\n        fake_response = FakeSampleResponse(\n            sequences=[FakeSampledSequence(tokens=[10], logprobs=None)]\n        )\n        sampling_client = MagicMock()\n        sampling_client.sample_async = AsyncMock(return_value=fake_response)\n\n        renderer = MagicMock()\n        renderer.build_generation_prompt.return_value = MagicMock()\n        renderer.build_generation_prompt.return_value.to_ints.return_value = [1]\n        renderer.parse_response.return_value = (\n            {\"role\": \"assistant\", \"content\": \"ok\"},\n            True,\n        )\n\n        await _sample_chat_completion(\n            sampling_client=sampling_client,\n            renderer=renderer,\n            messages=[{\"role\": \"user\", \"content\": \"Hi\"}],\n            stop=[\"STOP\"],\n        )\n\n        call_kwargs = sampling_client.sample_async.call_args.kwargs\n        assert call_kwargs[\"sampling_params\"].stop == [\"STOP\"]\n        # get_stop_sequences should NOT be called when stop is explicit\n        renderer.get_stop_sequences.assert_not_called()\n\n\n# ---------------------------------------------------------------------------\n# LiteLLM provider\n# ---------------------------------------------------------------------------\n\n\nclass TestTinkerLiteLLMProvider:\n    def test_register_adds_to_provider_map(self) -> None:\n        import litellm\n\n        import tinker_cookbook.third_party.litellm.provider as provider_mod\n        from tinker_cookbook.third_party.litellm import register_litellm_provider\n\n        # Reset the singleton so we can test fresh registration\n        old_registered = provider_mod._registered_provider\n        provider_mod._registered_provider = None\n\n        provider = None\n        try:\n            original_len = len(litellm.custom_provider_map)\n            provider = register_litellm_provider()\n            assert len(litellm.custom_provider_map) == original_len + 1\n            entry = litellm.custom_provider_map[-1]\n            assert entry[\"provider\"] == \"tinker\"\n            assert entry[\"custom_handler\"] is provider\n\n            # Calling again returns the same instance without adding a duplicate\n            provider2 = register_litellm_provider()\n            assert provider2 is provider\n            assert len(litellm.custom_provider_map) == original_len + 1\n        finally:\n            # Clean up\n            if provider is not None:\n                litellm.custom_provider_map[:] = [\n                    e\n                    for e in litellm.custom_provider_map\n                    if e.get(\"custom_handler\") is not provider\n                ]\n            provider_mod._registered_provider = old_registered\n\n    def test_set_client_creates_bundle(self) -> None:\n        from tinker_cookbook.third_party.litellm.provider import TinkerLiteLLMProvider\n\n        provider = TinkerLiteLLMProvider()\n        mock_client = MagicMock()\n        mock_client.get_base_model.return_value = \"Qwen/Qwen3-8B\"\n\n        with (\n            patch(\"tinker_cookbook.third_party.litellm.provider.get_tokenizer\") as mock_get_tok,\n            patch(\n                \"tinker_cookbook.third_party.litellm.provider.get_recommended_renderer_name\",\n       
         return_value=\"qwen3\",\n            ),\n            patch(\"tinker_cookbook.third_party.litellm.provider.renderers.get_renderer\"),\n        ):\n            mock_get_tok.return_value = MagicMock()\n            provider.set_client(mock_client)\n\n        assert \"Qwen/Qwen3-8B\" in provider._clients\n        assert provider._clients[\"Qwen/Qwen3-8B\"].sampling_client is mock_client\n\n    def test_set_client_updates_existing_bundle(self) -> None:\n        from tinker_cookbook.third_party.litellm.provider import (\n            TinkerLiteLLMProvider,\n            _ClientBundle,\n        )\n\n        provider = TinkerLiteLLMProvider()\n        old_client = MagicMock()\n        new_client = MagicMock()\n        new_client.get_base_model.return_value = \"Qwen/Qwen3-8B\"\n\n        provider._clients[\"Qwen/Qwen3-8B\"] = _ClientBundle(\n            sampling_client=old_client,\n            renderer=MagicMock(),\n            tokenizer=MagicMock(),\n            base_model=\"Qwen/Qwen3-8B\",\n        )\n\n        provider.set_client(new_client)\n        assert provider._clients[\"Qwen/Qwen3-8B\"].sampling_client is new_client\n\n    @pytest.mark.asyncio\n    async def test_acompletion_requires_base_model(self) -> None:\n        from tinker_cookbook.third_party.litellm.provider import TinkerLiteLLMProvider\n\n        provider = TinkerLiteLLMProvider()\n        model_response = MagicMock()\n\n        with pytest.raises(ValueError, match=\"base_model is required\"):\n            await provider.acompletion(\n                model=\"tinker/test\",\n                messages=[],\n                api_base=\"\",\n                custom_prompt_dict={},\n                model_response=model_response,\n                print_verbose=print,\n                encoding=None,\n                api_key=None,\n                logging_obj=MagicMock(),\n                optional_params={},\n                litellm_params={},\n            )\n\n    @pytest.mark.asyncio\n    async def test_acompletion_basic(self) -> None:\n        from tinker_cookbook.third_party.litellm.provider import (\n            TinkerLiteLLMProvider,\n            _ClientBundle,\n        )\n\n        provider = TinkerLiteLLMProvider()\n\n        fake_response = FakeSampleResponse(\n            sequences=[FakeSampledSequence(tokens=[10, 20], logprobs=[0.1, 0.2])]\n        )\n        mock_sampling_client = MagicMock()\n        mock_sampling_client.sample_async = AsyncMock(return_value=fake_response)\n\n        mock_renderer = MagicMock()\n        mock_renderer.build_generation_prompt.return_value = MagicMock()\n        mock_renderer.build_generation_prompt.return_value.to_ints.return_value = [1, 2, 3]\n        mock_renderer.get_stop_sequences.return_value = [\"<|end|>\"]\n        mock_renderer.parse_response.return_value = (\n            {\"role\": \"assistant\", \"content\": \"Hello!\"},\n            True,\n        )\n\n        provider._clients[\"Qwen/Qwen3-8B\"] = _ClientBundle(\n            sampling_client=mock_sampling_client,\n            renderer=mock_renderer,\n            tokenizer=MagicMock(),\n            base_model=\"Qwen/Qwen3-8B\",\n        )\n\n        model_response = MagicMock()\n\n        result = await provider.acompletion(\n            model=\"tinker/my-model\",\n            messages=[{\"role\": \"user\", \"content\": \"Hi\"}],\n            api_base=\"\",\n            custom_prompt_dict={},\n            model_response=model_response,\n            print_verbose=print,\n            encoding=None,\n            api_key=None,\n      
      logging_obj=MagicMock(),\n            optional_params={\"temperature\": 0.7, \"max_tokens\": 64},\n            litellm_params={\"base_model\": \"Qwen/Qwen3-8B\"},\n        )\n\n        assert result is model_response\n        # Verify the response was populated\n        fields = result.choices[0].message.provider_specific_fields\n        assert fields is not None\n        assert fields[\"prompt_token_ids\"] == [1, 2, 3]\n        assert fields[\"completion_token_ids\"] == [10, 20]\n"
  },
  {
    "path": "tinker_cookbook/third_party/openai_compat.py",
    "content": "\"\"\"OpenAI format compatibility utilities for tinker-cookbook.\n\nStateless conversion between OpenAI API message/tool formats and tinker-cookbook's\ninternal Message/ToolSpec/ToolCall types.\n\nThe reverse direction (tinker -> OpenAI) is handled by ``Renderer.to_openai_message()``.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom typing import Any\n\nfrom tinker_cookbook.renderers.base import (\n    Message,\n    ToolCall,\n    ToolSpec,\n)\n\n\ndef openai_messages_to_tinker(messages: list[dict[str, Any]]) -> list[Message]:\n    \"\"\"Convert OpenAI/LiteLLM message dicts to tinker-cookbook Messages.\"\"\"\n    out: list[Message] = []\n    for msg in messages:\n        tinker_msg: Message = {\n            \"role\": msg[\"role\"],\n            \"content\": msg.get(\"content\") or \"\",\n        }\n        if \"name\" in msg:\n            tinker_msg[\"name\"] = msg[\"name\"]\n        if \"tool_call_id\" in msg:\n            tinker_msg[\"tool_call_id\"] = msg[\"tool_call_id\"]\n        if \"tool_calls\" in msg:\n            tinker_msg[\"tool_calls\"] = [ToolCall.model_validate(tc) for tc in msg[\"tool_calls\"]]\n        out.append(tinker_msg)\n    return out\n\n\ndef openai_tools_to_tinker(tools: list[dict[str, Any]]) -> list[ToolSpec]:\n    \"\"\"Convert OpenAI-format tool dicts to renderer ToolSpec.\"\"\"\n    out: list[ToolSpec] = []\n    for tool in tools:\n        if tool.get(\"type\") != \"function\":\n            continue\n        func = tool[\"function\"]\n        out.append(\n            ToolSpec(\n                name=func[\"name\"],\n                description=func.get(\"description\", \"\"),\n                parameters=func.get(\"parameters\", {}),\n            )\n        )\n    return out\n"
  },
  {
    "path": "tinker_cookbook/third_party/openai_compat_test.py",
    "content": "\"\"\"Tests for OpenAI format compatibility utilities.\"\"\"\n\nfrom __future__ import annotations\n\nfrom tinker_cookbook.renderers.base import ToolCall\nfrom tinker_cookbook.third_party.openai_compat import (\n    openai_messages_to_tinker,\n    openai_tools_to_tinker,\n)\n\n# ---------------------------------------------------------------------------\n# openai_messages_to_tinker\n# ---------------------------------------------------------------------------\n\n\nclass TestOpenAIMessagesToTinker:\n    def test_basic_messages(self) -> None:\n        messages = [\n            {\"role\": \"system\", \"content\": \"You are helpful.\"},\n            {\"role\": \"user\", \"content\": \"Hi\"},\n        ]\n        result = openai_messages_to_tinker(messages)\n        assert len(result) == 2\n        assert result[0][\"role\"] == \"system\"\n        assert result[0][\"content\"] == \"You are helpful.\"\n        assert result[1][\"role\"] == \"user\"\n\n    def test_message_with_tool_call_id(self) -> None:\n        messages = [\n            {\"role\": \"tool\", \"content\": \"result\", \"tool_call_id\": \"call_123\"},\n        ]\n        result = openai_messages_to_tinker(messages)\n        assert result[0].get(\"tool_call_id\") == \"call_123\"\n\n    def test_message_with_name(self) -> None:\n        messages = [{\"role\": \"user\", \"content\": \"hi\", \"name\": \"Alice\"}]\n        result = openai_messages_to_tinker(messages)\n        assert result[0].get(\"name\") == \"Alice\"\n\n    def test_message_with_tool_calls(self) -> None:\n        messages = [\n            {\n                \"role\": \"assistant\",\n                \"content\": None,\n                \"tool_calls\": [\n                    {\n                        \"type\": \"function\",\n                        \"id\": \"call_1\",\n                        \"function\": {\"name\": \"search\", \"arguments\": '{\"q\": \"test\"}'},\n                    }\n                ],\n            }\n        ]\n        result = openai_messages_to_tinker(messages)\n        tcs = result[0].get(\"tool_calls\")\n        assert tcs is not None\n        assert len(tcs) == 1\n        assert isinstance(tcs[0], ToolCall)\n        assert tcs[0].function.name == \"search\"\n\n    def test_none_content_becomes_empty_string(self) -> None:\n        messages = [{\"role\": \"assistant\", \"content\": None}]\n        result = openai_messages_to_tinker(messages)\n        assert result[0][\"content\"] == \"\"\n\n\n# ---------------------------------------------------------------------------\n# openai_tools_to_tinker\n# ---------------------------------------------------------------------------\n\n\nclass TestOpenAIToolsToTinker:\n    def test_basic_tool(self) -> None:\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": \"get_weather\",\n                    \"description\": \"Get weather for a city\",\n                    \"parameters\": {\n                        \"type\": \"object\",\n                        \"properties\": {\"city\": {\"type\": \"string\"}},\n                    },\n                },\n            }\n        ]\n        result = openai_tools_to_tinker(tools)\n        assert len(result) == 1\n        assert result[0][\"name\"] == \"get_weather\"\n        assert result[0][\"description\"] == \"Get weather for a city\"\n        assert \"properties\" in result[0][\"parameters\"]\n\n    def test_skips_non_function_tools(self) -> None:\n        tools = 
[\n            {\"type\": \"retrieval\"},\n            {\n                \"type\": \"function\",\n                \"function\": {\"name\": \"search\", \"description\": \"Search\", \"parameters\": {}},\n            },\n        ]\n        result = openai_tools_to_tinker(tools)\n        assert len(result) == 1\n        assert result[0][\"name\"] == \"search\"\n\n    def test_missing_description(self) -> None:\n        tools = [\n            {\n                \"type\": \"function\",\n                \"function\": {\"name\": \"noop\", \"parameters\": {}},\n            }\n        ]\n        result = openai_tools_to_tinker(tools)\n        assert result[0][\"description\"] == \"\"\n\n    def test_empty_tools(self) -> None:\n        assert openai_tools_to_tinker([]) == []\n"
  },
  {
    "path": "tinker_cookbook/tokenizer_utils.py",
    "content": "\"\"\"\nUtilities for working with tokenizers. Create new types to avoid needing to import AutoTokenizer and PreTrainedTokenizer.\n\n\nAvoid importing AutoTokenizer and PreTrainedTokenizer until runtime, because they're slow imports.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom collections.abc import Callable\nfrom functools import cache\nfrom typing import TYPE_CHECKING, Any, TypeAlias\n\nif TYPE_CHECKING:\n    # this import takes a few seconds, so avoid it on the module import when possible\n    from transformers import PreTrainedTokenizer\n\n    Tokenizer: TypeAlias = PreTrainedTokenizer\nelse:\n    # make it importable from other files as a type in runtime\n    Tokenizer: TypeAlias = Any\n\n# Global registry for custom tokenizer factories\n_CUSTOM_TOKENIZER_REGISTRY: dict[str, Callable[[], Tokenizer]] = {}\n\n\ndef register_tokenizer(\n    name: str,\n    factory: Callable[[], Tokenizer],\n) -> None:\n    \"\"\"Register a custom tokenizer factory.\n\n    Args:\n        name: The tokenizer name\n        factory: A callable that takes no arguments and returns a Tokenizer.\n\n    Example:\n        def my_tokenizer_factory():\n            return MyCustomTokenizer()\n\n        register_tokenizer(\"Foo/foo_tokenizer\", my_tokenizer_factory)\n    \"\"\"\n    _CUSTOM_TOKENIZER_REGISTRY[name] = factory\n\n\ndef get_registered_tokenizer_names() -> list[str]:\n    \"\"\"Return a list of all registered custom tokenizer names.\"\"\"\n    return list(_CUSTOM_TOKENIZER_REGISTRY.keys())\n\n\ndef is_tokenizer_registered(name: str) -> bool:\n    \"\"\"Check if a tokenizer name is registered.\"\"\"\n    return name in _CUSTOM_TOKENIZER_REGISTRY\n\n\ndef unregister_tokenizer(name: str) -> bool:\n    \"\"\"Unregister a custom tokenizer factory.\n\n    Args:\n        name: The tokenizer name to unregister.\n\n    Returns:\n        True if the tokenizer was unregistered, False if it wasn't registered.\n    \"\"\"\n    if name in _CUSTOM_TOKENIZER_REGISTRY:\n        del _CUSTOM_TOKENIZER_REGISTRY[name]\n        return True\n    return False\n\n\ndef get_tokenizer(model_name: str) -> Tokenizer:\n    \"\"\"Get a tokenizer by name.\n\n    Checks custom registry first, then falls back to HuggingFace AutoTokenizer.\n    \"\"\"\n    # Check custom registry first (not cached, factory handles caching if needed)\n    if (tokenizer := _CUSTOM_TOKENIZER_REGISTRY.get(model_name)) is not None:\n        return tokenizer()\n\n    return _get_hf_tokenizer(model_name)\n\n\n@cache\ndef _get_hf_tokenizer(model_name: str) -> Tokenizer:\n    from transformers.models.auto.tokenization_auto import AutoTokenizer\n\n    model_name = model_name.split(\":\")[0]\n\n    # Avoid gating of Llama 3 models:\n    if model_name.startswith(\"meta-llama/Llama-3\"):\n        model_name = \"thinkingmachineslabinc/meta-llama-3-instruct-tokenizer\"\n\n    kwargs: dict[str, Any] = {}\n    if os.environ.get(\"HF_TRUST_REMOTE_CODE\", \"\").lower() in (\"1\", \"true\", \"yes\"):\n        kwargs[\"trust_remote_code\"] = True\n\n    if model_name == \"moonshotai/Kimi-K2-Thinking\":\n        kwargs[\"trust_remote_code\"] = True\n        kwargs[\"revision\"] = \"a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55\"\n    elif model_name == \"moonshotai/Kimi-K2.5\":\n        kwargs[\"trust_remote_code\"] = True\n        kwargs[\"revision\"] = \"2426b45b6af0da48d0dcce71bbce6225e5c73adc\"\n\n    return AutoTokenizer.from_pretrained(model_name, use_fast=True, **kwargs)\n"
  },
  {
    "path": "tinker_cookbook/tokenizer_utils_test.py",
    "content": "from unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom tinker_cookbook.tokenizer_utils import _get_hf_tokenizer\n\n\n@pytest.fixture(autouse=True)\ndef _clear_cache() -> None:\n    \"\"\"Clear the lru_cache between tests so env var changes take effect.\"\"\"\n    _get_hf_tokenizer.cache_clear()\n\n\n@patch(\"transformers.models.auto.tokenization_auto.AutoTokenizer\")\ndef test_kimi_k2_thinking_trusts_remote_code_without_env(\n    mock_auto: MagicMock, monkeypatch: pytest.MonkeyPatch\n) -> None:\n    \"\"\"Hardcoded Kimi models should pass trust_remote_code=True without the env var.\"\"\"\n    monkeypatch.delenv(\"HF_TRUST_REMOTE_CODE\", raising=False)\n    _get_hf_tokenizer(\"moonshotai/Kimi-K2-Thinking\")\n    mock_auto.from_pretrained.assert_called_once_with(\n        \"moonshotai/Kimi-K2-Thinking\",\n        use_fast=True,\n        trust_remote_code=True,\n        revision=\"a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55\",\n    )\n\n\n@patch(\"transformers.models.auto.tokenization_auto.AutoTokenizer\")\ndef test_kimi_k25_trusts_remote_code_without_env(\n    mock_auto: MagicMock, monkeypatch: pytest.MonkeyPatch\n) -> None:\n    \"\"\"Hardcoded Kimi K2.5 should pass trust_remote_code=True without the env var.\"\"\"\n    monkeypatch.delenv(\"HF_TRUST_REMOTE_CODE\", raising=False)\n    _get_hf_tokenizer(\"moonshotai/Kimi-K2.5\")\n    mock_auto.from_pretrained.assert_called_once_with(\n        \"moonshotai/Kimi-K2.5\",\n        use_fast=True,\n        trust_remote_code=True,\n        revision=\"2426b45b6af0da48d0dcce71bbce6225e5c73adc\",\n    )\n\n\n@patch(\"transformers.models.auto.tokenization_auto.AutoTokenizer\")\ndef test_no_trust_remote_code_by_default(\n    mock_auto: MagicMock, monkeypatch: pytest.MonkeyPatch\n) -> None:\n    \"\"\"Without env var, generic models should NOT get trust_remote_code.\"\"\"\n    monkeypatch.delenv(\"HF_TRUST_REMOTE_CODE\", raising=False)\n    _get_hf_tokenizer(\"some-org/some-model\")\n    mock_auto.from_pretrained.assert_called_once_with(\n        \"some-org/some-model\",\n        use_fast=True,\n    )\n\n\n@pytest.mark.parametrize(\"env_value\", [\"1\", \"true\", \"TRUE\", \"yes\", \"Yes\"])\n@patch(\"transformers.models.auto.tokenization_auto.AutoTokenizer\")\ndef test_env_var_enables_trust_remote_code(\n    mock_auto: MagicMock, monkeypatch: pytest.MonkeyPatch, env_value: str\n) -> None:\n    \"\"\"HF_TRUST_REMOTE_CODE env var should enable trust_remote_code for any model.\"\"\"\n    monkeypatch.setenv(\"HF_TRUST_REMOTE_CODE\", env_value)\n    _get_hf_tokenizer(\"some-org/some-model\")\n    mock_auto.from_pretrained.assert_called_once_with(\n        \"some-org/some-model\",\n        use_fast=True,\n        trust_remote_code=True,\n    )\n\n\n@pytest.mark.parametrize(\"env_value\", [\"0\", \"false\", \"no\", \"\"])\n@patch(\"transformers.models.auto.tokenization_auto.AutoTokenizer\")\ndef test_env_var_falsy_values_do_not_enable(\n    mock_auto: MagicMock, monkeypatch: pytest.MonkeyPatch, env_value: str\n) -> None:\n    \"\"\"Falsy values for HF_TRUST_REMOTE_CODE should not enable trust_remote_code.\"\"\"\n    monkeypatch.setenv(\"HF_TRUST_REMOTE_CODE\", env_value)\n    _get_hf_tokenizer(\"some-org/some-model\")\n    mock_auto.from_pretrained.assert_called_once_with(\n        \"some-org/some-model\",\n        use_fast=True,\n    )\n"
  },
  {
    "path": "tinker_cookbook/tool_use/README.md",
    "content": "# Tool Use Library\n\n> **Note:** This library is currently experimental and may change without warning.\n\nA library for training tool-use agents with Tinker.\n\n## Overview\n\nThe `tool_use` library provides:\n\n- **`@tool` decorator** - Define tools from Python functions with automatic schema extraction\n- **`Tool` protocol** - Interface for implementing custom tools\n- **`AgentToolMessageEnv`** - RL environment for training tool-use agents\n\n## Quick Example\n\n```python\nfrom tinker_cookbook.tool_use import tool, simple_tool_result, build_agent_tool_env\n\n@tool\nasync def search(query: Annotated[str, \"Search query\"]) -> ToolResult:\n    \"\"\"Search for information.\"\"\"\n    results = await do_search(query)\n    return simple_tool_result(json.dumps(results))\n\nenv = build_agent_tool_env(\n    renderer=renderer,\n    tools=[search],\n    initial_messages=messages,\n    reward_fn=my_reward_fn,\n    max_turns=5,\n)\n```\n\n## Stateful Tools\n\nStateful tools, including tools that share state, can be constructed by adding the `@tool` decorator to class methods with instance state:\n\n```python\nclass MyTools:\n    def __init__(self, api_key: str):\n        self._api_key = api_key\n\n    @tool\n    async def search(self, query: Annotated[str, \"Query\"]) -> ToolResult:\n        \"\"\"Search using the configured API.\"\"\"\n        results = await search_api(query, self._api_key)\n        return simple_tool_result(json.dumps(results))\n\n    @tool\n    async def lookup(self, id: Annotated[str, \"Document ID\"]) -> ToolResult:\n        \"\"\"Look up a document by ID.\"\"\"\n        result = await lookup_api(id, self._api_key)\n        return simple_tool_result(json.dumps(result))\n\n# Usage - both tools share the same api_key\ntools_obj = MyTools(api_key=\"...\")\nenv = build_agent_tool_env(..., tools=[tools_obj.search, tools_obj.lookup])\n```\n\n## Tool Lifetimes\n\nThe lifetime of an instantiated tool can be controlled by where it's instantiated:\n\n| Instantiation Location | Lifetime |\n|------------------------|----------|\n| In environment construction | Per trajectory |\n| In environment group construction | Per task (shared across trajectories) |\n| In full dataset construction | Entire training run |\n\n**Per-trajectory** (fresh state each rollout):\n```python\n# A fresh tool is instantiated for each Env\nasync def make_envs(self) -> Sequence[Env]:\n    return [\n        build_agent_tool_env(tools=[CodeTool(self.task).run])\n        for _ in range(self.group_size)\n    ]\n```\n\n**Shared across trajectories** (stateless or shared client):\n```python\n# A single tool is instantiated, and shared across Envs\ndef __init__(self, task, chroma_tool: ChromaTool):\n    self.chroma_tool = chroma_tool  # Created once, reused\n\nasync def make_envs(self) -> Sequence[Env]:\n    return [\n        build_agent_tool_env(tools=[self.chroma_tool.search])  # Same instance\n        for _ in range(self.group_size)\n    ]\n```\n\n## Examples\n\nFor examples of using the tool-use library, see the the following:\n\n- [code_rl recipe](../recipes/code_rl/) - Code generation with python execution tool\n- [search_tool recipe](../recipes/search_tool/) - Multi-hop QA with search tool\n"
  },
  {
    "path": "tinker_cookbook/tool_use/__init__.py",
    "content": "\"\"\"Tool-use library.\"\"\"\n\nfrom tinker_cookbook.tool_use.agent_tool_message_env import (\n    AgentToolMessageEnv,\n    build_agent_tool_env,\n)\nfrom tinker_cookbook.tool_use.tools import (\n    FunctionTool,\n    error_tool_result,\n    handle_tool_call,\n    simple_tool_result,\n    tool,\n)\nfrom tinker_cookbook.tool_use.types import (\n    Tool,\n    ToolInput,\n    ToolResult,\n    ToolSpec,\n)\n\n__all__ = [\n    \"AgentToolMessageEnv\",\n    \"build_agent_tool_env\",\n    \"FunctionTool\",\n    \"Tool\",\n    \"ToolInput\",\n    \"ToolResult\",\n    \"ToolSpec\",\n    \"error_tool_result\",\n    \"handle_tool_call\",\n    \"simple_tool_result\",\n    \"tool\",\n]\n"
  },
  {
    "path": "tinker_cookbook/tool_use/agent_tool_message_env.py",
    "content": "\"\"\"Tool-using agent environment.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nfrom collections.abc import Awaitable, Callable\nfrom dataclasses import dataclass, field\n\nfrom tinker_cookbook.renderers import Renderer\nfrom tinker_cookbook.renderers.base import Message, ToolCall, get_text_content\nfrom tinker_cookbook.rl import types\nfrom tinker_cookbook.rl.message_env import EnvFromMessageEnv, MessageEnv, MessageStepResult\nfrom tinker_cookbook.tool_use.tools import handle_tool_call\nfrom tinker_cookbook.tool_use.types import Tool\n\nRewardResult = tuple[float, dict[str, float]]\nRewardFn = Callable[[list[Message]], Awaitable[RewardResult]]\n# TODO(tyler): Consider supporting stateful tools that need to grade rollouts based on\n# information not contained in the message history (e.g., internal tool state that changes\n# during execution).\n\n\n@dataclass\nclass AgentToolMessageEnv(MessageEnv):\n    \"\"\"Generic tool-use MessageEnv for agents.\"\"\"\n\n    tools: list[Tool]\n    initial_messages: list[Message]\n    max_turns: int\n    reward_fn: RewardFn\n    history: list[Message] = field(default_factory=list)\n\n    _turn_count: int = 0\n    _tool_dict: dict[str, Tool] = field(default_factory=dict, init=False)\n    _should_stop: bool = field(default=False, init=False)\n\n    def __post_init__(self) -> None:\n        self._tool_dict = {t.name: t for t in self.tools}\n\n    async def initial_observation(self) -> list[Message]:\n        if not self.history:\n            self.history = list(self.initial_messages)\n        return self.history\n\n    async def _handle_tool_calls(self, tool_calls: list[ToolCall]) -> list[Message]:\n        \"\"\"Execute tool calls and append results to history.\n\n        Note: Tool metrics are not accumulated in the message history.\n        Only messages and should_stop are used from ToolResult.\n        \"\"\"\n        tool_results = await asyncio.gather(\n            *[handle_tool_call(self._tool_dict, tc) for tc in tool_calls]\n        )\n\n        all_messages: list[Message] = []\n\n        for tool_result in tool_results:\n            # Append messages to history\n            for msg in tool_result.messages:\n                self.history.append(msg)\n                all_messages.append(msg)\n\n            # Check if any tool signals to stop\n            if tool_result.should_stop:\n                self._should_stop = True\n\n        return all_messages\n\n    async def step(self, message: Message) -> MessageStepResult:\n        \"\"\"Execute any tools and return next messages.\n\n        The episode ends when:\n        - no tool calls in message (model decided to stop)\n        - a tool returns should_stop=True\n        - max_turns reached\n\n        reward_fn is called once at episode end to grade the full trajectory.\n        \"\"\"\n        self._turn_count += 1\n        metrics: dict[str, float] = {}\n        logs: types.Logs = {}\n\n        # Append the message to history\n        self.history.append(message)\n\n        # Log assistant content (handles both str and multimodal content)\n        assistant_text = get_text_content(message)\n        if assistant_text:\n            logs[\"assistant_content\"] = assistant_text\n\n        # Extract and execute tool calls if present\n        tool_calls: list[ToolCall] = list(message.get(\"tool_calls\") or [])\n        if tool_calls:\n            for i, tc in enumerate(tool_calls):\n                logs[f\"tool_call_{i}\"] = 
f\"{tc.function.name}({tc.function.arguments})\"\n\n            tool_result_messages = await self._handle_tool_calls(tool_calls)\n\n            for i, msg in enumerate(tool_result_messages):\n                logs[f\"tool_result_{i}\"] = get_text_content(msg)\n\n        # Determine if episode is done\n        no_tool_calls = len(tool_calls) == 0\n        max_turns_reached = self._turn_count >= self.max_turns\n        done = no_tool_calls or max_turns_reached or self._should_stop\n\n        if max_turns_reached and not no_tool_calls:\n            metrics[\"max_turns\"] = 1.0\n        if self._should_stop:\n            metrics[\"tool_stopped\"] = 1.0\n\n        reward = 0.0\n        if done:\n            reward, reward_metrics = await self.reward_fn(self.history)\n            metrics.update(reward_metrics)\n\n        return MessageStepResult(\n            reward=reward,\n            episode_done=done,\n            next_messages=self.history,\n            metrics=metrics,\n            logs=logs,\n        )\n\n\ndef build_agent_tool_env(\n    renderer: Renderer,\n    tools: list[Tool],\n    initial_messages: list[Message],\n    reward_fn: RewardFn,\n    *,\n    max_turns: int = 5,\n    failed_parse_reward: float = -0.1,\n    max_trajectory_tokens: int | None = None,\n) -> EnvFromMessageEnv:\n    \"\"\"Convenience method to build an EnvFromMessageEnv for tool-using agents.\n\n    Args:\n        renderer: The renderer for tokenizing messages.\n        tools: List of tools the agent can call (must implement Tool protocol).\n        initial_messages: Initial conversation history (system prompt, user message, etc.).\n        reward_fn: Function that grades a completed episode. Takes the full message\n            history and returns (reward, metrics). Called once at episode end.\n        max_turns: Maximum turns before episode ends.\n        failed_parse_reward: Reward when model output fails to parse.\n        max_trajectory_tokens: Maximum tokens in trajectory before terminating episode.\n\n    Returns:\n        An EnvFromMessageEnv ready for RL training.\n    \"\"\"\n    msg_env = AgentToolMessageEnv(\n        tools=tools,\n        initial_messages=initial_messages,\n        max_turns=max_turns,\n        reward_fn=reward_fn,\n    )\n    return EnvFromMessageEnv(\n        renderer=renderer,\n        message_env=msg_env,\n        failed_parse_reward=failed_parse_reward,\n        max_trajectory_tokens=max_trajectory_tokens,\n    )\n"
  },
  {
    "path": "tinker_cookbook/tool_use/agent_tool_message_env_test.py",
    "content": "\"\"\"Tests for AgentToolMessageEnv log population.\"\"\"\n\nimport asyncio\nfrom typing import Any\n\nfrom tinker_cookbook.renderers.base import Message, ToolCall, ToolSpec\nfrom tinker_cookbook.tool_use.agent_tool_message_env import AgentToolMessageEnv\nfrom tinker_cookbook.tool_use.tools import simple_tool_result\nfrom tinker_cookbook.tool_use.types import ToolInput, ToolResult\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\nasync def _noop_reward(history: list[Message]) -> tuple[float, dict[str, float]]:\n    return 1.0, {}\n\n\nclass StubTool:\n    \"\"\"Minimal Tool implementation for testing.\"\"\"\n\n    def __init__(self, name: str, response: str, should_stop: bool = False):\n        self._name = name\n        self._response = response\n        self._should_stop = should_stop\n\n    @property\n    def name(self) -> str:\n        return self._name\n\n    @property\n    def description(self) -> str:\n        return f\"Stub tool: {self._name}\"\n\n    @property\n    def parameters_schema(self) -> dict[str, Any]:\n        return {\"type\": \"object\", \"properties\": {}}\n\n    async def run(self, input: ToolInput) -> ToolResult:\n        return simple_tool_result(\n            self._response,\n            call_id=input.call_id or \"\",\n            name=self._name,\n            should_stop=self._should_stop,\n        )\n\n    def to_spec(self) -> ToolSpec:\n        return {\n            \"name\": self._name,\n            \"description\": self.description,\n            \"parameters\": self.parameters_schema,\n        }\n\n\ndef _make_tool_call(name: str, arguments: str = \"{}\", call_id: str = \"call_1\") -> ToolCall:\n    return ToolCall(id=call_id, function=ToolCall.FunctionBody(name=name, arguments=arguments))\n\n\n# ---------------------------------------------------------------------------\n# Tests\n# ---------------------------------------------------------------------------\n\n\nclass TestStepLogs:\n    \"\"\"AgentToolMessageEnv.step() should populate logs with diagnostic info.\"\"\"\n\n    def test_logs_assistant_content(self):\n        \"\"\"Logs include assistant_content when message has text content.\"\"\"\n        env = AgentToolMessageEnv(\n            tools=[],\n            initial_messages=[{\"role\": \"user\", \"content\": \"hi\"}],\n            max_turns=5,\n            reward_fn=_noop_reward,\n        )\n        asyncio.run(env.initial_observation())\n\n        result = asyncio.run(env.step({\"role\": \"assistant\", \"content\": \"Hello world\"}))\n\n        assert result.logs[\"assistant_content\"] == \"Hello world\"\n\n    def test_logs_empty_when_no_content(self):\n        \"\"\"Logs omit assistant_content when message has empty content.\"\"\"\n        env = AgentToolMessageEnv(\n            tools=[],\n            initial_messages=[{\"role\": \"user\", \"content\": \"hi\"}],\n            max_turns=5,\n            reward_fn=_noop_reward,\n        )\n        asyncio.run(env.initial_observation())\n\n        result = asyncio.run(env.step({\"role\": \"assistant\", \"content\": \"\"}))\n\n        assert \"assistant_content\" not in result.logs\n\n    def test_logs_multimodal_content(self):\n        \"\"\"Logs extract text from multimodal content via get_text_content.\"\"\"\n        env = AgentToolMessageEnv(\n            tools=[],\n            initial_messages=[{\"role\": \"user\", \"content\": \"hi\"}],\n            
max_turns=5,\n            reward_fn=_noop_reward,\n        )\n        asyncio.run(env.initial_observation())\n\n        result = asyncio.run(\n            env.step(\n                {\n                    \"role\": \"assistant\",\n                    \"content\": [{\"type\": \"text\", \"text\": \"extracted text\"}],\n                }\n            )\n        )\n\n        assert result.logs[\"assistant_content\"] == \"extracted text\"\n\n    def test_logs_tool_calls_and_results(self):\n        \"\"\"Logs include tool call names/args and tool result content.\"\"\"\n        search_tool = StubTool(\"search\", '{\"results\": [\"a\", \"b\"]}')\n        env = AgentToolMessageEnv(\n            tools=[search_tool],\n            initial_messages=[{\"role\": \"user\", \"content\": \"find stuff\"}],\n            max_turns=5,\n            reward_fn=_noop_reward,\n        )\n        asyncio.run(env.initial_observation())\n\n        tc = _make_tool_call(\"search\", '{\"query\": \"weather\"}')\n        result = asyncio.run(\n            env.step({\"role\": \"assistant\", \"content\": \"Let me search.\", \"tool_calls\": [tc]})\n        )\n\n        assert result.logs[\"assistant_content\"] == \"Let me search.\"\n        assert result.logs[\"tool_call_0\"] == 'search({\"query\": \"weather\"})'\n        assert result.logs[\"tool_result_0\"] == '{\"results\": [\"a\", \"b\"]}'\n\n    def test_logs_multiple_tool_calls(self):\n        \"\"\"Logs index multiple tool calls and results separately.\"\"\"\n        search_tool = StubTool(\"search\", \"search result\")\n        calc_tool = StubTool(\"calc\", \"42\")\n        env = AgentToolMessageEnv(\n            tools=[search_tool, calc_tool],\n            initial_messages=[{\"role\": \"user\", \"content\": \"hi\"}],\n            max_turns=5,\n            reward_fn=_noop_reward,\n        )\n        asyncio.run(env.initial_observation())\n\n        tc1 = _make_tool_call(\"search\", '{\"q\": \"x\"}', call_id=\"call_1\")\n        tc2 = _make_tool_call(\"calc\", '{\"expr\": \"1+1\"}', call_id=\"call_2\")\n        result = asyncio.run(\n            env.step({\"role\": \"assistant\", \"content\": \"Doing both.\", \"tool_calls\": [tc1, tc2]})\n        )\n\n        assert result.logs[\"tool_call_0\"] == 'search({\"q\": \"x\"})'\n        assert result.logs[\"tool_call_1\"] == 'calc({\"expr\": \"1+1\"})'\n        assert result.logs[\"tool_result_0\"] == \"search result\"\n        assert result.logs[\"tool_result_1\"] == \"42\"\n\n    def test_logs_no_tool_calls(self):\n        \"\"\"When there are no tool calls, only assistant_content is logged.\"\"\"\n        env = AgentToolMessageEnv(\n            tools=[],\n            initial_messages=[{\"role\": \"user\", \"content\": \"hi\"}],\n            max_turns=5,\n            reward_fn=_noop_reward,\n        )\n        asyncio.run(env.initial_observation())\n\n        result = asyncio.run(env.step({\"role\": \"assistant\", \"content\": \"Just text.\"}))\n\n        assert result.logs == {\"assistant_content\": \"Just text.\"}\n        assert \"tool_call_0\" not in result.logs\n        assert \"tool_result_0\" not in result.logs\n"
  },
  {
    "path": "tinker_cookbook/tool_use/tools.py",
    "content": "\"\"\"Tool-use library for LLM agents.\"\"\"\n\nfrom __future__ import annotations\n\nimport asyncio\nimport inspect\nimport json\nfrom collections.abc import Callable\nfrom typing import (\n    Annotated,\n    Any,\n    get_args,\n    get_origin,\n    get_type_hints,\n)\n\nfrom pydantic import BaseModel, Field, create_model\nfrom pydantic.fields import FieldInfo\nfrom pydantic_core import PydanticUndefined\n\nfrom tinker_cookbook.renderers.base import ToolCall, ToolSpec\nfrom tinker_cookbook.tool_use.types import Tool, ToolInput, ToolResult\n\n\ndef simple_tool_result(\n    content: str,\n    *,\n    call_id: str = \"\",\n    name: str = \"\",\n    should_stop: bool = False,\n    metrics: dict[str, float] | None = None,\n    metadata: dict[str, Any] | None = None,\n) -> ToolResult:\n    \"\"\"Helper function to create a simple ToolResult from a content string.\n\n    Args:\n        content: The content to return to the model.\n        call_id: The tool call ID (usually passed from ToolInput).\n        name: The tool name (usually self.name in a tool method).\n        should_stop: Whether to stop the episode after this tool call.\n        metrics: Optional metrics dict (e.g., {\"latency\": 0.5, \"count\": 1}).\n        metadata: Optional metadata dict for debugging.\n\n    Returns:\n        A ToolResult with the given content and options.\n\n    Example:\n        @tool\n        async def search(query: str) -> ToolResult:\n            results = await do_search(query)\n            return simple_tool_result(\n                json.dumps(results),\n                metrics={\"result_count\": len(results)}\n            )\n    \"\"\"\n    return ToolResult(\n        messages=[\n            {\n                \"role\": \"tool\",\n                \"content\": content,\n                \"tool_call_id\": call_id,\n                \"name\": name,\n            }\n        ],\n        should_stop=should_stop,\n        metrics=metrics or {},\n        metadata=metadata or {},\n    )\n\n\ndef error_tool_result(\n    error_msg: str,\n    *,\n    call_id: str = \"\",\n    name: str = \"\",\n    error_type: str = \"execution_error\",\n    should_stop: bool = False,\n) -> ToolResult:\n    \"\"\"Helper function to create an error ToolResult.\n\n    Args:\n        error_msg: The error message to return.\n        call_id: The tool call ID (usually from ToolInput).\n        name: The tool name.\n        error_type: Error category for metadata (e.g., \"validation_failed\").\n        should_stop: Whether to stop the episode after this error.\n\n    Returns:\n        A ToolResult with error message and metadata.\n\n    Example:\n        except Exception as e:\n            return error_tool_result(\n                f\"Parameter validation failed: {e}\",\n                call_id=input.call_id or \"\",\n                name=self.name,\n                error_type=\"validation_failed\"\n            )\n    \"\"\"\n    return ToolResult(\n        messages=[\n            {\n                \"role\": \"tool\",\n                \"content\": json.dumps({\"error\": error_msg}),\n                \"tool_call_id\": call_id,\n                \"name\": name,\n            }\n        ],\n        should_stop=should_stop,\n        metrics={},\n        metadata={\"error\": error_type},\n    )\n\n\ndef _extract_annotated_info(annotation: Any) -> tuple[Any, FieldInfo | None, str | None]:\n    \"\"\"\n    Extract the base type, FieldInfo, and description from an Annotated type.\n\n    This is used by the @tool decorator to 
extract info about the tool's parameters.\n    \"\"\"\n    if get_origin(annotation) is not Annotated:\n        return annotation, None, None\n\n    args = get_args(annotation)\n    base_type = args[0]\n    field_info = None\n    description = None\n\n    for meta in args[1:]:\n        if isinstance(meta, str) and description is None:\n            description = meta\n        elif isinstance(meta, FieldInfo):\n            field_info = meta\n            if meta.description and description is None:\n                description = meta.description\n\n    return base_type, field_info, description\n\n\nclass FunctionTool:\n    \"\"\"\n    A tool created from a decorated function or method.\n\n    Implements the Tool protocol. Used internally by the @tool decorator.\n    \"\"\"\n\n    def __init__(self, fn: Callable[..., Any]):\n        self._fn = fn\n        self._instance: Any = None  # Will be set when accessed as descriptor\n        self._name = fn.__name__\n        self._description = fn.__doc__ or \"\"\n        self._params_model = self._build_params_model()\n\n    @property\n    def name(self) -> str:\n        return self._name\n\n    @property\n    def description(self) -> str:\n        return self._description\n\n    def _build_params_model(self) -> type[BaseModel]:\n        \"\"\"Build a Pydantic model from the function signature.\"\"\"\n        hints = get_type_hints(self._fn, include_extras=True)\n        sig = inspect.signature(self._fn)\n\n        fields: dict[str, Any] = {}\n        for param_name, param in sig.parameters.items():\n            if param_name == \"self\":\n                continue\n\n            annotation = hints.get(param_name, Any)\n            base_type, field_info, desc = _extract_annotated_info(annotation)\n\n            if param.default is inspect.Parameter.empty:\n                default = ...\n            else:\n                default = param.default\n\n            if field_info is not None:\n                if field_info.default is PydanticUndefined and default is not ...:\n                    field_info.default = default\n                fields[param_name] = (base_type, field_info)\n            else:\n                fields[param_name] = (base_type, Field(default, description=desc))\n\n        return create_model(f\"{self._name}_params\", **fields)\n\n    @property\n    def parameters_schema(self) -> dict[str, Any]:\n        \"\"\"JSON Schema for tool parameters.\"\"\"\n        return self._params_model.model_json_schema()\n\n    def to_spec(self) -> ToolSpec:\n        \"\"\"Convert to ToolSpec for renderer integration.\"\"\"\n        return {\n            \"name\": self.name,\n            \"description\": self.description,\n            \"parameters\": self.parameters_schema,\n        }\n\n    async def run(self, input: ToolInput) -> ToolResult:\n        \"\"\"Execute the tool with validated arguments. 
Returns a ToolResult.\"\"\"\n        # Validate arguments\n        try:\n            validated = self._params_model.model_validate(input.arguments)\n        except Exception as e:\n            return error_tool_result(\n                f\"Parameter validation failed: {e}\",\n                call_id=input.call_id or \"\",\n                name=self.name,\n                error_type=\"validation_failed\",\n            )\n\n        # Execute function\n        try:\n            kwargs = validated.model_dump()\n            args = (self._instance,) if self._instance is not None else ()\n            if asyncio.iscoroutinefunction(self._fn):\n                result = await self._fn(*args, **kwargs)\n            else:\n                result = self._fn(*args, **kwargs)\n\n            # Function must return ToolResult\n            if not isinstance(result, ToolResult):\n                raise TypeError(\n                    f\"Tool function '{self.name}' must return ToolResult, \"\n                    f\"got {type(result).__name__}. \"\n                    f\"Use simple_tool_result() helper for simple cases.\"\n                )\n\n            # Fill in call_id and name if not provided\n            for msg in result.messages:\n                if not msg.get(\"tool_call_id\") and input.call_id:\n                    msg[\"tool_call_id\"] = input.call_id\n                if not msg.get(\"name\"):\n                    msg[\"name\"] = self.name\n\n            return result\n\n        except Exception as e:\n            return error_tool_result(\n                f\"Tool execution failed: {e}\",\n                call_id=input.call_id or \"\",\n                name=self.name,\n                error_type=\"execution_failed\",\n            )\n\n    def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool:\n        \"\"\"Descriptor protocol: bind to instance when accessed as method.\"\"\"\n        if obj is None:\n            return self\n        # Create a bound copy\n        bound = FunctionTool.__new__(FunctionTool)\n        bound._fn = self._fn\n        bound._instance = obj\n        bound._name = self._name\n        bound._description = self._description\n        bound._params_model = self._params_model\n        return bound\n\n\ndef tool(fn: Callable[..., Any]) -> FunctionTool:\n    \"\"\"\n    Decorator to create a tool from a function or method.\n\n    The decorated function must return a ToolResult. 
Use simple_tool_result() helper\n    for basic cases, or construct ToolResult directly for metrics/metadata.\n\n    Usage:\n        @tool\n        async def search(query: Annotated[str, \"The search query\"]) -> ToolResult:\n            '''Search for information.'''\n            results = await do_search(query)\n            return simple_tool_result(\n                json.dumps({\"results\": results}),\n                metrics={\"result_count\": len(results)}\n            )\n\n        # As class method with shared state:\n        class MySharedStateTools:\n            def __init__(self, api_key: str):\n                self.api_key = api_key\n\n            @tool\n            async def search(self, query: Annotated[str, \"Query\"]) -> ToolResult:\n                '''Search for information.'''\n                results = await do_search(query, api_key=self.api_key)\n                return simple_tool_result(json.dumps(results))\n    \"\"\"\n    return FunctionTool(fn)\n\n\nasync def handle_tool_call(\n    tools: dict[str, Tool],\n    tool_call: ToolCall,\n) -> ToolResult:\n    \"\"\"Handle a single tool call, returning a ToolResult.\"\"\"\n    tool_name = tool_call.function.name\n    tool_call_id = tool_call.id or \"\"\n\n    if tool_name not in tools:\n        return error_tool_result(\n            f\"Tool '{tool_name}' not found\",\n            call_id=tool_call_id,\n            name=tool_name,\n            error_type=\"tool_not_found\",\n        )\n\n    tool_obj = tools[tool_name]\n    try:\n        arguments = json.loads(tool_call.function.arguments)\n    except json.JSONDecodeError as e:\n        return error_tool_result(\n            f\"Failed to parse tool arguments: {e}\",\n            call_id=tool_call_id,\n            name=tool_name,\n            error_type=\"json_decode_failed\",\n        )\n\n    return await tool_obj.run(ToolInput(arguments=arguments, call_id=tool_call_id))\n"
  },
  {
    "path": "tinker_cookbook/tool_use/types.py",
    "content": "\"\"\"Core types for tool-use library.\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Protocol, runtime_checkable\n\nfrom tinker_cookbook.renderers.base import Message, ToolSpec\n\n\n@dataclass\nclass ToolInput:\n    \"\"\"Input to a tool invocation.\"\"\"\n\n    arguments: dict[str, Any]\n    call_id: str | None = None\n\n\n@dataclass\nclass ToolResult:\n    \"\"\"Result from a tool invocation.\"\"\"\n\n    messages: list[Message]\n    should_stop: bool = False\n    metrics: dict[str, float] = field(default_factory=dict)\n    metadata: dict[str, Any] = field(default_factory=dict)\n\n\n@runtime_checkable\nclass Tool(Protocol):\n    \"\"\"Protocol for tools that can be used by LLM agents.\"\"\"\n\n    @property\n    def name(self) -> str:\n        \"\"\"Tool name shown to the model.\"\"\"\n        ...\n\n    @property\n    def description(self) -> str:\n        \"\"\"Tool description shown to the model.\"\"\"\n        ...\n\n    @property\n    def parameters_schema(self) -> dict[str, Any]:\n        \"\"\"JSON Schema for tool parameters shown to the model.\"\"\"\n        ...\n\n    async def run(self, input: ToolInput) -> ToolResult:\n        \"\"\"Execute the tool with validated arguments. Returns a ToolResult.\"\"\"\n        ...\n\n    def to_spec(self) -> ToolSpec:\n        \"\"\"Convert to ToolSpec for renderer integration.\"\"\"\n        return {\n            \"name\": self.name,\n            \"description\": self.description,\n            \"parameters\": self.parameters_schema,\n        }\n"
  },
  {
    "path": "tinker_cookbook/utils/__init__.py",
    "content": "\"\"\"Utility helpers for tinker-cookbook.\"\"\"\n"
  },
  {
    "path": "tinker_cookbook/utils/code_state.py",
    "content": "from __future__ import annotations\n\nimport importlib\nimport subprocess\nfrom collections.abc import Sequence\nfrom pathlib import Path\nfrom types import ModuleType\nfrom typing import cast\n\n\ndef code_state(modules: Sequence[str | ModuleType] = (\"tinker_cookbook\",)) -> str:\n    \"\"\"\n    Return a single diff-formatted string that captures the current code state for the\n    provided Python modules. For each module, we:\n\n    - Locate the module on the filesystem\n    - Discover the enclosing Git repository (the module may live inside a larger repo)\n    - Record the current commit hash (HEAD)\n    - Include combined staged+unstaged changes (i.e., diff vs HEAD) for the entire\n      containing Git repository (repo-wide). Subtree diffs are omitted to avoid\n      duplication.\n\n    The output is suitable for storage alongside experiment or training metadata to\n    reproduce the exact code state later. When multiple modules are provided, their\n    sections are concatenated in order.\n\n    Parameters:\n    - modules: sequence of module import names (e.g., \"tinker_cookbook.rl\") or already-\n      imported module objects. All entries must be either `str` or `ModuleType`.\n\n    Returns:\n    - A string beginning with a header per module of the form:\n      \"### module: <module_name> (repo: <repo_root>) @ <commit_hash>\" followed by\n      the staged and unstaged `git diff` outputs restricted to that module directory.\n      If a module is not in a Git repository, a short note is included instead.\n    \"\"\"\n\n    def ensure_module(obj: str | ModuleType) -> ModuleType:\n        if isinstance(obj, ModuleType):\n            return obj\n        assert isinstance(obj, str), (\n            \"Each item in modules must be a module object or import path string\"\n        )\n        return importlib.import_module(obj)\n\n    def find_module_dir(mod: ModuleType) -> Path:\n        # Prefer package path if available, else use the file's directory\n        mod_file = cast(str | None, getattr(mod, \"__file__\", None))\n        mod_path_list = cast(Sequence[str] | None, getattr(mod, \"__path__\", None))\n        assert (mod_file is not None) or (mod_path_list is not None), (\n            f\"Module {mod!r} lacks __file__/__path__\"\n        )\n        if mod_path_list is not None:  # packages expose __path__ (iterable); pick the first entry\n            first_path = next(iter(mod_path_list))\n            return Path(first_path).resolve()\n        assert mod_file is not None\n        return Path(mod_file).resolve().parent\n\n    def git_toplevel(start_dir: Path) -> Path | None:\n        try:\n            completed = subprocess.run(\n                [\"git\", \"-C\", str(start_dir), \"rev-parse\", \"--show-toplevel\"],\n                check=True,\n                capture_output=True,\n                text=True,\n            )\n            return Path(completed.stdout.strip()).resolve()\n        except subprocess.CalledProcessError:\n            return None\n\n    def git_rev(head_dir: Path) -> str:\n        completed = subprocess.run(\n            [\"git\", \"-C\", str(head_dir), \"rev-parse\", \"HEAD\"],\n            check=True,\n            capture_output=True,\n            text=True,\n        )\n        return completed.stdout.strip()\n\n    def git_diff_vs_head(head_dir: Path) -> str:\n        \"\"\"Return a repo-wide unified diff of working tree + index (staged and\n        unstaged) relative to HEAD.\"\"\"\n        args = [\"git\", \"-C\", str(head_dir), \"diff\", 
\"--no-color\", \"HEAD\"]\n        completed = subprocess.run(args, check=False, capture_output=True, text=True)\n        return completed.stdout\n\n    sections: list[str] = []\n\n    # Group modules by their enclosing repo and track non-git modules\n    repos_to_modules: dict[Path, list[str]] = {}\n    nongit_modules: list[tuple[str, Path]] = []\n\n    for obj in modules:\n        mod = ensure_module(obj)\n        mod_name = mod.__name__\n        mod_dir = find_module_dir(mod)\n        repo_root = git_toplevel(mod_dir)\n\n        if repo_root is None:\n            nongit_modules.append((mod_name, mod_dir))\n            continue\n\n        if repo_root not in repos_to_modules:\n            repos_to_modules[repo_root] = []\n        if mod_name not in repos_to_modules[repo_root]:\n            repos_to_modules[repo_root].append(mod_name)\n\n    # Emit one section per repo with a single repo-wide diff\n    for repo_root in sorted(repos_to_modules.keys(), key=lambda p: str(p)):\n        try:\n            head = git_rev(repo_root)\n        except subprocess.CalledProcessError:\n            head = \"UNKNOWN\"\n\n        diff_repo = git_diff_vs_head(repo_root)\n        mod_names = \", \".join(sorted(repos_to_modules[repo_root]))\n        header = f\"### repo: {repo_root} @ {head}\\nmodules: {mod_names}\\n\"\n        if diff_repo:\n            body = \"-- repo-wide (vs HEAD, staged+unstaged) --\\n\" + diff_repo.rstrip() + \"\\n\"\n        else:\n            body = \"(no local changes)\\n\"\n        sections.append(header + body)\n\n    # Notes for modules not in a git repo\n    for mod_name, mod_dir in nongit_modules:\n        sections.append(\n            f\"### module: {mod_name} (not in a git repository)\\nmodule_path: {mod_dir}\\n\"\n        )\n\n    return \"\\n\".join(sections)\n"
  },
  {
    "path": "tinker_cookbook/utils/deprecation.py",
    "content": "\"\"\"\nDeprecation utilities for managing API evolution in tinker-cookbook.\n\nThis module provides tools for deprecating functions, classes, parameters,\nand module-level attributes with clear migration guidance and automatic\nenforcement when the removal version is reached.\n\nUsage examples::\n\n    from tinker_cookbook.utils.deprecation import deprecated, warn_deprecated\n\n    # Deprecate an entire function or class\n    @deprecated(message=\"Use new_func() instead.\", removal_version=\"0.20.0\")\n    def old_func(x):\n        return new_func(x)\n\n    # Deprecate inside a function body (e.g., a parameter)\n    def train(*, lr, learning_rate=None):\n        if learning_rate is not None:\n            warn_deprecated(\n                \"learning_rate\",\n                removal_version=\"0.20.0\",\n                message=\"Use the 'lr' parameter instead.\",\n            )\n            lr = learning_rate\n        ...\n\n    # Deprecate a module-level attribute (put in the module's __init__.py)\n    __getattr__ = make_deprecated_module_getattr(\n        __name__,\n        {\"OldClass\": (\"new_module.NewClass\", \"0.20.0\")},\n    )\n\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nimport importlib\nimport importlib.metadata\nimport warnings\nfrom collections.abc import Callable\nfrom typing import Any, TypeVar, overload\n\nF = TypeVar(\"F\", bound=Callable[..., Any])\n\n\ndef _parse_version(v: str) -> tuple[int, ...]:\n    \"\"\"Parse a version string into a comparable tuple of ints.\n\n    Strips any pre-release/dev suffixes (e.g. ``\"0.15.0.dev3+g1234\"`` becomes\n    ``(0, 15, 0)``).  This avoids a dependency on ``packaging``.\n    \"\"\"\n    parts: list[int] = []\n    for segment in v.split(\".\"):\n        digits = \"\"\n        for ch in segment:\n            if ch.isdigit():\n                digits += ch\n            else:\n                break\n        if digits:\n            parts.append(int(digits))\n        else:\n            break\n    return tuple(parts) if parts else (0,)\n\n\ndef _current_version() -> tuple[int, ...]:\n    \"\"\"Return the current package version as a comparable tuple.\"\"\"\n    try:\n        raw = importlib.metadata.version(\"tinker_cookbook\")\n        return _parse_version(raw)\n    except Exception:\n        return (0, 0, 0)\n\n\ndef _check_past_removal(removal_version: str | None) -> bool:\n    \"\"\"Return True if the current version is at or past the removal version.\"\"\"\n    if removal_version is None:\n        return False\n    try:\n        return _current_version() >= _parse_version(removal_version)\n    except Exception:\n        return False\n\n\ndef warn_deprecated(\n    name: str,\n    *,\n    removal_version: str | None = None,\n    message: str = \"\",\n    stacklevel: int = 2,\n) -> None:\n    \"\"\"Emit a DeprecationWarning for a deprecated feature.\n\n    If the current package version is at or past *removal_version*, raises\n    a ``RuntimeError`` instead so that stale deprecated code paths are not\n    silently used after their intended removal date.\n\n    Args:\n        name: Short identifier for the deprecated feature (e.g. function name,\n            parameter name).\n        removal_version: The version in which this feature will be removed.\n            When the running version reaches this value the warning becomes\n            a hard error.  
Pass ``None`` to warn without a scheduled removal.\n        message: Additional guidance, typically a migration path such as\n            \"Use X instead.\"\n        stacklevel: Passed through to ``warnings.warn``. The default of 2\n            points at the caller of the function that calls ``warn_deprecated``.\n    \"\"\"\n    parts: list[str] = [f\"'{name}' is deprecated.\"]\n    if removal_version is not None:\n        parts.append(f\"It will be removed in version {removal_version}.\")\n    if message:\n        parts.append(message)\n    full_message = \" \".join(parts)\n\n    if _check_past_removal(removal_version):\n        raise RuntimeError(\n            f\"{full_message} (Current version is \"\n            f\"{'.'.join(str(x) for x in _current_version())}; \"\n            f\"this should have been removed by {removal_version}.)\"\n        )\n\n    warnings.warn(full_message, DeprecationWarning, stacklevel=stacklevel)\n\n\n@overload\ndef deprecated(__func: F) -> F: ...\n\n\n@overload\ndef deprecated(\n    *,\n    message: str = ...,\n    removal_version: str | None = ...,\n) -> Callable[[F], F]: ...\n\n\ndef deprecated(\n    _func: Callable[..., Any] | None = None,\n    *,\n    message: str = \"\",\n    removal_version: str | None = None,\n) -> Any:\n    \"\"\"Decorator to mark a function or class as deprecated.\n\n    Can be used with or without arguments::\n\n        @deprecated\n        def old(): ...\n\n        @deprecated(message=\"Use new_func instead.\", removal_version=\"0.20.0\")\n        def old(): ...\n\n    When applied to a class, the warning is emitted at instantiation time.\n    \"\"\"\n\n    def decorator(obj: F) -> F:\n        obj_name: str = getattr(obj, \"__qualname__\", getattr(obj, \"__name__\", str(obj)))\n        kind = \"Class\" if isinstance(obj, type) else \"Function\"\n\n        if isinstance(obj, type):\n            original_init: Callable[..., None] = obj.__init__  # type: ignore[misc]\n\n            @functools.wraps(original_init)\n            def new_init(self: Any, *args: Any, **kwargs: Any) -> None:\n                warn_deprecated(\n                    f\"{kind} {obj_name}\",\n                    removal_version=removal_version,\n                    message=message,\n                    stacklevel=2,\n                )\n                original_init(self, *args, **kwargs)\n\n            obj.__init__ = new_init  # type: ignore[misc]\n            return obj  # type: ignore[return-value]\n        else:\n\n            @functools.wraps(obj)\n            def wrapper(*args: Any, **kwargs: Any) -> Any:\n                warn_deprecated(\n                    f\"{kind} {obj_name}\",\n                    removal_version=removal_version,\n                    message=message,\n                    stacklevel=2,\n                )\n                return obj(*args, **kwargs)\n\n            return wrapper  # type: ignore[return-value]\n\n    if _func is not None:\n        return decorator(_func)\n\n    return decorator\n\n\ndef make_deprecated_module_getattr(\n    module_name: str,\n    attrs: dict[str, tuple[str, str | None]],\n) -> Callable[[str], Any]:\n    \"\"\"Create a ``__getattr__`` function for deprecating module-level attributes.\n\n    Returns a function suitable for assigning to ``__getattr__`` at module scope.\n    When an old attribute name is accessed, it emits a deprecation warning and\n    transparently returns the new object.\n\n    Args:\n        module_name: ``__name__`` of the module defining ``__getattr__``.\n        attrs: Mapping of 
``{old_name: (dotted_path_to_new, removal_version)}``.\n            *dotted_path_to_new* is ``\"package.module.NewName\"`` and will be\n            imported and returned.  *removal_version* may be ``None``.\n\n    Returns:\n        A ``__getattr__`` function.\n\n    Example::\n\n        # In mymodule/__init__.py\n        __getattr__ = make_deprecated_module_getattr(\n            __name__,\n            {\"OldThing\": (\"mymodule.new_place.NewThing\", \"0.20.0\")},\n        )\n    \"\"\"\n\n    def __getattr__(name: str) -> Any:\n        if name not in attrs:\n            raise AttributeError(f\"module {module_name!r} has no attribute {name!r}\")\n\n        new_path, removal_version = attrs[name]\n\n        module_path, _, attr_name = new_path.rpartition(\".\")\n        if not module_path:\n            raise ValueError(\n                f\"make_deprecated_module_getattr: new path {new_path!r} must be a \"\n                f\"dotted path (e.g. 'package.module.Name')\"\n            )\n\n        mod = importlib.import_module(module_path)\n        replacement = getattr(mod, attr_name)\n\n        warn_deprecated(\n            f\"{module_name}.{name}\",\n            removal_version=removal_version,\n            message=f\"Use {new_path} instead.\",\n            stacklevel=2,\n        )\n        return replacement\n\n    return __getattr__\n"
  },
  {
    "path": "tinker_cookbook/utils/deprecation_test.py",
    "content": "\"\"\"Tests for the deprecation utilities.\"\"\"\n\nfrom __future__ import annotations\n\nimport warnings\nfrom unittest.mock import patch\n\nimport pytest\n\nfrom tinker_cookbook.utils.deprecation import (\n    _parse_version,\n    deprecated,\n    make_deprecated_module_getattr,\n    warn_deprecated,\n)\n\n# ---------------------------------------------------------------------------\n# _parse_version\n# ---------------------------------------------------------------------------\n\n\nclass TestParseVersion:\n    def test_simple(self):\n        assert _parse_version(\"1.2.3\") == (1, 2, 3)\n\n    def test_dev_suffix(self):\n        assert _parse_version(\"0.15.0.dev3+g1234\") == (0, 15, 0)\n\n    def test_prerelease(self):\n        assert _parse_version(\"1.0.0rc1\") == (1, 0, 0)\n\n    def test_single_number(self):\n        assert _parse_version(\"42\") == (42,)\n\n    def test_empty_fallback(self):\n        assert _parse_version(\"\") == (0,)\n\n\n# ---------------------------------------------------------------------------\n# warn_deprecated\n# ---------------------------------------------------------------------------\n\n\nclass TestWarnDeprecated:\n    def test_basic_warning(self):\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            warn_deprecated(\"old_func\")\n        assert len(w) == 1\n        assert issubclass(w[0].category, DeprecationWarning)\n        assert \"'old_func' is deprecated.\" in str(w[0].message)\n\n    def test_warning_with_removal_version(self):\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            warn_deprecated(\"old_func\", removal_version=\"99.0.0\")\n        assert \"removed in version 99.0.0\" in str(w[0].message)\n\n    def test_warning_with_message(self):\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            warn_deprecated(\"old_func\", message=\"Use new_func() instead.\")\n        assert \"Use new_func() instead.\" in str(w[0].message)\n\n    def test_full_warning_message(self):\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            warn_deprecated(\n                \"my_feature\",\n                removal_version=\"99.0.0\",\n                message=\"Use better_feature instead.\",\n            )\n        msg = str(w[0].message)\n        assert \"'my_feature' is deprecated.\" in msg\n        assert \"removed in version 99.0.0\" in msg\n        assert \"Use better_feature instead.\" in msg\n\n    def test_past_removal_version_raises(self):\n        with patch(\"tinker_cookbook.utils.deprecation._current_version\") as mock_ver:\n            mock_ver.return_value = (1, 0, 0)\n            with pytest.raises(RuntimeError, match=\"should have been removed\"):\n                warn_deprecated(\"old_func\", removal_version=\"0.5.0\")\n\n    def test_no_removal_version_never_raises(self):\n        \"\"\"When removal_version is None, it always warns, never raises.\"\"\"\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            warn_deprecated(\"old_func\", removal_version=None)\n        assert len(w) == 1\n\n\n# ---------------------------------------------------------------------------\n# @deprecated decorator\n# ---------------------------------------------------------------------------\n\n\nclass TestDeprecatedDecorator:\n    def 
test_decorate_function_no_args(self):\n        @deprecated\n        def old_func() -> str:\n            return \"result\"\n\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            result = old_func()\n\n        assert result == \"result\"\n        assert len(w) == 1\n        assert issubclass(w[0].category, DeprecationWarning)\n        assert \"old_func\" in str(w[0].message)\n\n    def test_decorate_function_with_message(self):\n        @deprecated(message=\"Use new_func instead.\", removal_version=\"99.0.0\")\n        def old_func() -> str:\n            return \"result\"\n\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            result = old_func()\n\n        assert result == \"result\"\n        assert \"Use new_func instead.\" in str(w[0].message)\n        assert \"99.0.0\" in str(w[0].message)\n\n    def test_decorate_function_empty_parens(self):\n        @deprecated()\n        def old_func() -> str:\n            return \"result\"\n\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            result = old_func()\n\n        assert result == \"result\"\n        assert len(w) == 1\n\n    def test_preserves_function_metadata(self):\n        @deprecated(message=\"msg\", removal_version=\"99.0.0\")\n        def old_func() -> str:\n            \"\"\"Original docstring.\"\"\"\n            return \"result\"\n\n        assert old_func.__name__ == \"old_func\"\n        assert old_func.__doc__ == \"Original docstring.\"\n\n    def test_decorate_class(self):\n        @deprecated(message=\"Use NewClass instead.\", removal_version=\"99.0.0\")\n        class OldClass:\n            def __init__(self, x: int):\n                self.x = x\n\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            obj = OldClass(42)\n\n        assert obj.x == 42\n        assert len(w) == 1\n        assert \"OldClass\" in str(w[0].message)\n        assert \"Use NewClass instead.\" in str(w[0].message)\n\n    def test_class_preserves_isinstance(self):\n        @deprecated(message=\"Use NewClass instead.\")\n        class OldClass:\n            pass\n\n        with warnings.catch_warnings(record=True):\n            warnings.simplefilter(\"always\")\n            obj = OldClass()\n\n        assert isinstance(obj, OldClass)\n\n    def test_function_with_args_and_kwargs(self):\n        @deprecated(message=\"msg\")\n        def add(a: int, b: int, *, extra: int = 0) -> int:\n            return a + b + extra\n\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            result = add(1, 2, extra=3)\n\n        assert result == 6\n        assert len(w) == 1\n\n    def test_past_removal_raises_on_call(self):\n        @deprecated(message=\"Use new.\", removal_version=\"0.1.0\")\n        def old_func() -> str:\n            return \"result\"\n\n        with patch(\"tinker_cookbook.utils.deprecation._current_version\") as mock_ver:\n            mock_ver.return_value = (1, 0, 0)\n            with pytest.raises(RuntimeError, match=\"should have been removed\"):\n                old_func()\n\n\n# ---------------------------------------------------------------------------\n# make_deprecated_module_getattr\n# ---------------------------------------------------------------------------\n\n\nclass TestMakeDeprecatedModuleGetattr:\n    def 
test_unknown_attr_raises_attribute_error(self):\n        getattr_fn = make_deprecated_module_getattr(\"mymod\", {})\n        with pytest.raises(AttributeError, match=\"has no attribute\"):\n            getattr_fn(\"nonexistent\")\n\n    def test_redirects_with_warning(self):\n        getattr_fn = make_deprecated_module_getattr(\n            \"mymod\",\n            {\"OldPath\": (\"os.path.join\", \"99.0.0\")},\n        )\n\n        with warnings.catch_warnings(record=True) as w:\n            warnings.simplefilter(\"always\")\n            result = getattr_fn(\"OldPath\")\n\n        import os.path\n\n        assert result is os.path.join\n        assert len(w) == 1\n        assert \"mymod.OldPath\" in str(w[0].message)\n        assert \"os.path.join\" in str(w[0].message)\n\n    def test_bad_path_raises(self):\n        getattr_fn = make_deprecated_module_getattr(\n            \"mymod\",\n            {\"Bad\": (\"NoDots\", \"99.0.0\")},\n        )\n        with pytest.raises(ValueError, match=\"dotted path\"):\n            getattr_fn(\"Bad\")\n"
  },
  {
    "path": "tinker_cookbook/utils/file_utils.py",
    "content": "import json\n\n\ndef read_jsonl(path: str) -> list[dict]:\n    with open(path) as f:\n        return [json.loads(line) for line in f]\n"
  },
  {
    "path": "tinker_cookbook/utils/format_colorized.py",
    "content": "from termcolor import colored\n\nfrom tinker_cookbook.tokenizer_utils import Tokenizer\n\n\ndef format_colorized(\n    tokens: list[int], weights: list[float], tokenizer: Tokenizer, draw_newline_arrow: bool = False\n) -> str:\n    \"\"\"\n    Colour-code text according to per-token weights.\n\n    * Cyan text  → weight > 0\n    * Yellow text  → weight = 0\n    * Red text   → weight < 0\n\n    The function minimises ANSI escape sequences by wrapping *runs* of\n    like-coloured tokens, and decodes each run in a single call so that\n    multi-byte or multibyte-character languages (e.g. CJK) render correctly.\n    \"\"\"\n    if len(tokens) != len(weights):\n        raise ValueError(\"`tokens` and `weights` must be the same length.\")\n\n    chunks, current_ids, current_color = [], [], None\n\n    def flush_current_run():\n        decoded = str(tokenizer.decode(current_ids))\n        lines = decoded.splitlines(keepends=True)\n        for line in lines:\n            if draw_newline_arrow:\n                line = line.replace(\"\\n\", \"↵\\n\")\n            chunks.append(colored(line, current_color))\n\n    for tok_id, w in zip(tokens, weights, strict=True):\n        if w < 0:\n            color = \"red\"\n        elif w == 0:\n            color = \"yellow\"\n        else:\n            color = \"green\"\n\n        # Flush when the colour changes\n        if color != current_color and current_ids:\n            flush_current_run()\n            current_ids = []\n\n        current_ids.append(tok_id)\n        current_color = color\n\n    flush_current_run()\n\n    return \"\".join(chunks)\n"
  },
  {
    "path": "tinker_cookbook/utils/logtree.py",
    "content": "\"\"\"\nLogtree: Scope-based logging library for creating nested HTML reports.\n\nThis module provides a context-based API for generating structured HTML logs\nthat reflect the call tree of your code. Ideal for logging RL rollouts,\nmodel evaluations, and other hierarchical computations.\n\nExample usage:\n    import logtree\n\n    async def train_iteration():\n        with logtree.init_trace(\"Training Iteration 1\", path=\"output.html\"):\n            with logtree.scope_header(\"Sampling\"):\n                logtree.log_text(\"Generated 100 samples\")\n                await sample_data()\n\n            with logtree.scope_header(\"Training\"):\n                logtree.log_text(\"Loss: 0.42\")\n                await train_model()\n\"\"\"\n\nimport functools\nimport html as html_module\nimport inspect\nimport json\nimport os\nimport traceback\nfrom collections.abc import Callable, Iterator, Mapping, Sequence\nfrom contextlib import contextmanager\nfrom contextvars import ContextVar\nfrom dataclasses import dataclass, field\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any, Protocol, TypeVar, cast, overload\n\n# Context variables for task-local state\n_current_trace: ContextVar[\"Trace | None\"] = ContextVar(\"lt_current_trace\", default=None)\n_container_stack: ContextVar[\"tuple[Node, ...]\"] = ContextVar(\"lt_container_stack\", default=())\n_header_depth: ContextVar[\"tuple[int, ...]\"] = ContextVar(\"lt_header_depth\", default=())\n_logging_disabled: ContextVar[bool] = ContextVar(\"lt_logging_disabled\", default=False)\n\n\nclass Formatter(Protocol):\n    \"\"\"Protocol for objects that can format themselves as HTML with CSS.\n\n    Optionally implement ``to_data`` to attach structured data to the JSON\n    export (avoids consumers having to parse raw HTML).\n    \"\"\"\n\n    def to_html(self) -> str:\n        \"\"\"Generate HTML representation of this object.\"\"\"\n        ...\n\n    def get_css(self) -> str:\n        \"\"\"Get CSS needed to style this object's HTML.\"\"\"\n        ...\n\n    def to_data(self) -> dict[str, Any] | None:\n        \"\"\"Return structured data for JSON export, or None.\"\"\"\n        return None\n\n\n@dataclass\nclass Node:\n    \"\"\"Represents an HTML element in the tree.\"\"\"\n\n    tag: str\n    attrs: dict[str, str] = field(default_factory=dict)\n    children: list[\"Node | str\"] = field(default_factory=list)\n    # Optional structured data attached by formatters.\n    # Included in JSON export so consumers can extract typed content\n    # (e.g., conversation messages) without parsing raw HTML.\n    data: dict[str, Any] | None = field(default=None, repr=False)\n\n    def to_html(self, indent: int = 0) -> str:\n        \"\"\"Convert node to HTML string.\"\"\"\n        ind = \"  \" * indent\n        attrs_str = \"\".join(\n            f' {k}=\"{html_module.escape(v, quote=True)}\"' for k, v in self.attrs.items()\n        )\n\n        if not self.children:\n            return f\"{ind}<{self.tag}{attrs_str}></{self.tag}>\\n\"\n\n        # Keep simple text-only nodes on one line to avoid extra rendered\n        # whitespace when CSS uses `white-space: pre-wrap` (e.g., lt-p).\n        if all(isinstance(child, str) for child in self.children):\n            text = \"\".join(child for child in self.children if isinstance(child, str))\n            return f\"{ind}<{self.tag}{attrs_str}>{text}</{self.tag}>\\n\"\n\n        lines = [f\"{ind}<{self.tag}{attrs_str}>\\n\"]\n        for child in self.children:\n         
   if isinstance(child, str):\n                lines.append(child)\n            else:\n                lines.append(child.to_html(indent + 1))\n        lines.append(f\"{ind}</{self.tag}>\\n\")\n        return \"\".join(lines)\n\n    def to_dict(self) -> dict[str, Any]:\n        \"\"\"Convert node to a JSON-serializable dictionary.\n\n        When ``data`` is present, raw HTML string children are omitted from the\n        JSON — consumers should use ``data`` instead.\n        \"\"\"\n        if self.data is not None:\n            children = [child.to_dict() for child in self.children if isinstance(child, Node)]\n        else:\n            children = [\n                child if isinstance(child, str) else child.to_dict() for child in self.children\n            ]\n        d: dict[str, Any] = {\n            \"tag\": self.tag,\n            \"attrs\": dict(self.attrs),\n            \"children\": children,\n        }\n        if self.data is not None:\n            d[\"data\"] = self.data\n        return d\n\n\n@dataclass\nclass Theme:\n    \"\"\"Theme configuration for HTML output.\"\"\"\n\n    css_text: str | None = None  # Custom CSS; if None, use built-in\n    css_urls: list[str] = field(default_factory=list)\n    css_vars: dict[str, str] = field(default_factory=dict)  # CSS custom properties\n\n\nclass Trace:\n    \"\"\"Root trace object representing an HTML document.\"\"\"\n\n    def __init__(self, title: str, path: str | os.PathLike | None, write_on_error: bool):\n        self.title = title\n        self.path = Path(path) if path is not None else None\n        self.write_on_error = write_on_error\n        self.started_at = datetime.now()\n        self.root = Node(\"body\", {\"class\": \"lt-root\"})\n        self._formatter_css: set[str] = set()  # Deduplicated CSS from formatters\n\n    def _register_formatter_css(self, css: str) -> None:\n        \"\"\"Register CSS from a formatter (deduplicated per trace).\"\"\"\n        if css:\n            self._formatter_css.add(css)\n\n    def body_html(self, wrap_body: bool = True) -> str:\n        \"\"\"Get the body HTML.\"\"\"\n        inner = self.root.to_html(indent=0)\n        if wrap_body:\n            return inner\n        else:\n            # Return just the inner content\n            return \"\\n\".join(\n                line for line in inner.split(\"\\n\") if \"<body\" not in line and \"</body>\" not in line\n            )\n\n    def get_html(self) -> str:\n        \"\"\"Alias for body_html().\"\"\"\n        return self.body_html(wrap_body=True)\n\n    def head_html(\n        self, theme: Theme | None = None, title: str | None = None, extra_head: str | None = None\n    ) -> str:\n        \"\"\"Generate the <head> section of the HTML document.\"\"\"\n        if theme is None:\n            theme = Theme()\n\n        parts = []\n        parts.append(f\"<title>{html_module.escape(title or self.title)}</title>\")\n        parts.append('<meta charset=\"UTF-8\">')\n        parts.append('<meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">')\n\n        # External CSS\n        for url in theme.css_urls:\n            parts.append(f'<link rel=\"stylesheet\" href=\"{html_module.escape(url, quote=True)}\">')\n\n        # Inline CSS\n        css = theme.css_text if theme.css_text is not None else _DEFAULT_CSS\n        if css or self._formatter_css:\n            parts.append(\"<style>\")\n            if css:\n                parts.append(css)\n            # CSS custom properties\n            if theme.css_vars:\n                
parts.append(\":root {\")\n                for key, value in theme.css_vars.items():\n                    parts.append(f\"  {key}: {value};\")\n                parts.append(\"}\")\n            # Formatter CSS\n            if self._formatter_css:\n                parts.append(\"\\n/* Formatter CSS */\")\n                for formatter_css in self._formatter_css:\n                    parts.append(formatter_css)\n            parts.append(\"</style>\")\n\n        if extra_head:\n            parts.append(extra_head)\n\n        return \"\\n\".join(parts)\n\n    def to_dict(self) -> dict[str, Any]:\n        \"\"\"Convert the trace to a JSON-serializable dictionary.\"\"\"\n        return {\n            \"title\": self.title,\n            \"started_at\": self.started_at.isoformat(),\n            \"path\": str(self.path) if self.path is not None else None,\n            \"root\": self.root.to_dict(),\n        }\n\n\n# Default CSS styling\n_DEFAULT_CSS = \"\"\"\nbody {\n    font-family: -apple-system, BlinkMacSystemFont, \"Segoe UI\", Roboto, sans-serif;\n    line-height: 1.45;\n    max-width: 1200px;\n    margin: 0 auto;\n    padding: 14px;\n    background: var(--lt-bg, #f5f5f5);\n    color: var(--lt-text, #333);\n}\n\n.lt-root {\n    background: var(--lt-card, white);\n    padding: 1.2rem 1.4rem;\n    border-radius: 8px;\n    box-shadow: 0 2px 4px rgba(0,0,0,0.1);\n}\n\n.lt-title {\n    margin: 0 0 0.5rem 0;\n    color: var(--lt-accent, #2563eb);\n    border-bottom: 2px solid var(--lt-border, #e5e7eb);\n    padding-bottom: 0.5rem;\n}\n\n.lt-subtitle {\n    color: var(--lt-sub, #666);\n    font-size: 0.875rem;\n    margin-bottom: 1.2rem;\n}\n\n.lt-section {\n    margin: 0.95rem 0;\n    padding-left: 0.75rem;\n    border-left: 2px solid var(--lt-border, #e5e7eb);\n}\n\n.lt-section-body {\n    margin-top: 0.12rem;\n}\n\n.lt-section h2, .lt-section h3, .lt-section h4, .lt-section h5, .lt-section h6 {\n    margin: 0.2rem 0;\n    line-height: 1.3;\n    color: var(--lt-accent, #2563eb);\n}\n\n.lt-h2 { font-size: 1.15rem; }\n.lt-h3 { font-size: 1.05rem; }\n.lt-h4 { font-size: 0.98rem; }\n.lt-h5 { font-size: 0.95rem; }\n.lt-h6 { font-size: 0.92rem; }\n\n.lt-p {\n    margin: 0.2rem 0;\n    white-space: pre-wrap;\n}\n\n.lt-details {\n    margin: 0.35rem 0;\n    border: 1px solid var(--lt-border, #e5e7eb);\n    border-radius: 4px;\n    padding: 0.35rem 0.45rem;\n}\n\n.lt-details summary {\n    cursor: pointer;\n    font-weight: 600;\n    user-select: none;\n}\n\n.lt-details-body {\n    margin-top: 0.25rem;\n    padding: 0.35rem 0.45rem;\n    background: var(--lt-bg, #f5f5f5);\n    border-radius: 4px;\n    overflow-x: auto;\n}\n\n.lt-details-body pre {\n    margin: 0;\n    font-family: var(--lt-mono, \"Courier New\", monospace);\n    font-size: 0.875rem;\n    white-space: pre-wrap;\n}\n\n.lt-table {\n    border-collapse: separate;\n    border-spacing: 0;\n    width: 100%;\n    margin: 0.6rem 0;\n    font-size: 0.875rem;\n    border: 1px solid var(--lt-border, #d5dbe3);\n    border-radius: 6px;\n    background: #fff;\n    overflow: hidden;\n}\n\n.lt-table th {\n    background: var(--lt-table-head-bg, #eef2f7);\n    color: var(--lt-text, #1f2937);\n    padding: 0.5rem;\n    text-align: left;\n    font-weight: 600;\n    border-bottom: 1px solid var(--lt-border, #d5dbe3);\n}\n\n.lt-table td {\n    padding: 0.5rem;\n    border-bottom: 1px solid var(--lt-border, #e5e7eb);\n}\n\n.lt-table tr:nth-child(even) {\n    background: #f8fafc;\n}\n\n.lt-table-caption {\n    font-weight: 600;\n    margin-bottom: 0.5rem;\n    
color: var(--lt-text, #333);\n}\n\n.lt-exc {\n    background: #fee;\n    border: 2px solid #c00;\n    border-radius: 4px;\n    padding: 0.75rem;\n    margin: 0.5rem 0;\n}\n\n.lt-exc summary {\n    color: #c00;\n    font-weight: 700;\n    cursor: pointer;\n}\n\n.lt-exc pre {\n    margin-top: 0.5rem;\n    font-family: var(--lt-mono, \"Courier New\", monospace);\n    font-size: 0.875rem;\n    overflow-x: auto;\n}\n\n.answer, .reward {\n    font-weight: 600;\n    padding: 0.25rem 0.5rem;\n    border-radius: 4px;\n    display: inline-block;\n    margin: 0.25rem 0;\n}\n\n.answer {\n    background: #dbeafe;\n    color: #1e40af;\n}\n\n.reward {\n    background: #dcfce7;\n    color: #166534;\n}\n\"\"\"\n\n\n# Helper functions\n\n\ndef _normalize_attrs(**attrs: Any) -> dict[str, str]:\n    \"\"\"Normalize attribute names (class_ -> class, data__foo -> data-foo).\"\"\"\n    result = {}\n    for key, value in attrs.items():\n        if key == \"class_\":\n            key = \"class\"\n        elif key.startswith(\"data__\"):\n            key = key.replace(\"__\", \"-\", 1)\n        result[key] = str(value)\n    return result\n\n\ndef _append(node: Node) -> None:\n    \"\"\"Append a node to the current container.\"\"\"\n    stack = _container_stack.get()\n    if not stack:\n        raise RuntimeError(\"No active container to append to\")\n    stack[-1].children.append(node)\n\n\ndef _next_header_level() -> int:\n    \"\"\"Get the next header level based on current depth.\"\"\"\n    depth = _header_depth.get()\n    current = depth[-1] if depth else 1\n    return min(6, current + 1)\n\n\ndef _is_logging_enabled() -> bool:\n    \"\"\"Check if logging is currently enabled.\"\"\"\n    return _current_trace.get() is not None and not _logging_disabled.get()\n\n\n@contextmanager\ndef _in_container(node: Node) -> Iterator[None]:\n    \"\"\"Context manager to push/pop a container.\"\"\"\n    token = _container_stack.set(_container_stack.get() + (node,))\n    try:\n        yield\n    finally:\n        _container_stack.reset(token)\n\n\ndef _exception_block(exc: BaseException) -> Node:\n    \"\"\"Create a details block for an exception.\"\"\"\n    tb_str = \"\".join(traceback.format_exception(type(exc), exc, exc.__traceback__))\n    details_node = Node(\"details\", {\"class\": \"lt-exc\", \"open\": \"open\"})\n    details_node.children.append(Node(\"summary\", {}, [f\"Exception: {type(exc).__name__}: {exc}\"]))\n    pre_node = Node(\"pre\", {})\n    pre_node.children.append(html_module.escape(tb_str))\n    details_node.children.append(pre_node)\n    return details_node\n\n\ndef _write_trace(trace: Trace, theme: Theme | None = None) -> None:\n    \"\"\"Write the trace to disk.\"\"\"\n    if trace.path is None:\n        return\n\n    trace.path.parent.mkdir(parents=True, exist_ok=True)\n\n    with open(trace.path, \"w\") as f:\n        f.write(\"<!doctype html>\\n\")\n        f.write('<html lang=\"en\">\\n')\n        f.write(\"<head>\\n\")\n        f.write(trace.head_html(theme=theme))\n        f.write(\"</head>\\n\")\n        f.write(trace.body_html(wrap_body=True))\n        f.write(\"</html>\\n\")\n\n\n# Public API: Trace lifecycle\n\n\n@contextmanager\ndef init_trace(\n    title: str, path: str | os.PathLike | None = None, *, write_on_error: bool = True\n) -> Iterator[Trace]:\n    \"\"\"\n    Initialize a new trace context.\n\n    Args:\n        title: Title for the HTML document (becomes <h1>)\n        path: Path to write HTML file (None = don't write automatically)\n        write_on_error: If True, write partial 
HTML even if exception occurs\n\n    Example:\n        with logtree.init_trace(\"My Report\", path=\"output.html\"):\n            logtree.log_text(\"Hello world\")\n    \"\"\"\n    trace = Trace(title, path, write_on_error=write_on_error)\n\n    tok_t = _current_trace.set(trace)\n    tok_s = _container_stack.set((trace.root,))\n    tok_h = _header_depth.set((1,))\n\n    # Emit title and subtitle\n    _append(Node(\"h1\", {\"class\": \"lt-title\"}, [html_module.escape(title)]))\n    _append(\n        Node(\n            \"div\",\n            {\"class\": \"lt-subtitle\"},\n            [f\"Generated {trace.started_at.isoformat(timespec='seconds')}\"],\n        )\n    )\n\n    try:\n        yield trace\n    except BaseException as e:\n        _append(_exception_block(e))\n        if write_on_error and trace.path is not None:\n            _write_trace(trace)\n        raise\n    else:\n        if trace.path is not None:\n            _write_trace(trace)\n    finally:\n        # Always reset context even on exception\n        _header_depth.reset(tok_h)\n        _container_stack.reset(tok_s)\n        _current_trace.reset(tok_t)\n\n\n@contextmanager\ndef scope_header(title: str, **attrs: Any) -> Iterator[None]:\n    \"\"\"\n    Open a section with an auto-leveled header.\n\n    Args:\n        title: Text for the header\n        **attrs: HTML attributes (use class_=\"foo\" for class, data__x=\"y\" for data-x)\n\n    Example:\n        with logtree.scope_header(\"Results\", class_=\"important\"):\n            logtree.log_text(\"Success rate: 95%\")\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        yield\n        return\n\n    section = Node(\"section\", {\"class\": \"lt-section\", **_normalize_attrs(**attrs)})\n    _append(section)\n\n    tok_h = None\n    try:\n        with _in_container(section):\n            h = _next_header_level()\n            _append(Node(f\"h{h}\", {\"class\": f\"lt-h{h}\"}, [html_module.escape(title)]))\n\n            # Push header level for nested scopes\n            tok_h = _header_depth.set(_header_depth.get() + (h,))\n            body = Node(\"div\", {\"class\": \"lt-section-body\"})\n            _append(body)\n            with _in_container(body):\n                yield\n    finally:\n        # Always reset context even on exception\n        if tok_h is not None:\n            _header_depth.reset(tok_h)\n\n\nF = TypeVar(\"F\", bound=Callable[..., Any])\n\n\n# Overloads the parameterized usage\n@overload\ndef scope_header_decorator(title: str) -> Callable[[F], F]: ...  # String title\n\n\n# Overloads the bare usage\n@overload\ndef scope_header_decorator(title: F) -> F: ...  
# Bare: @scope_header_decorator\n\n\ndef scope_header_decorator(\n    title: str | F,\n) -> F | Callable[[F], F]:\n    \"\"\"\n    Decorator to wrap function in a scope_header.\n\n    Args:\n        title: String or function returning string\n\n    Examples:\n        @logtree.scope_header_decorator\n        async def process_batch():\n            ...\n\n        @logtree.scope_header_decorator(\"Handling item\")\n        def handle_item():\n            ...\n    \"\"\"\n    title_str = title if isinstance(title, str) else title.__name__\n\n    def _wrap(fn: F) -> F:\n        if inspect.iscoroutinefunction(fn):\n\n            @functools.wraps(fn)\n            async def aw(*args: Any, **kwargs: Any) -> Any:\n                # Graceful degradation: if logging is disabled, just run the function\n                if not _is_logging_enabled():\n                    return await fn(*args, **kwargs)\n\n                with scope_header(title_str):\n                    return await fn(*args, **kwargs)\n\n            return aw  # type: ignore\n        else:\n\n            @functools.wraps(fn)\n            def w(*args: Any, **kwargs: Any) -> Any:\n                # Graceful degradation: if logging is disabled, just run the function\n                if not _is_logging_enabled():\n                    return fn(*args, **kwargs)\n\n                with scope_header(title_str):\n                    return fn(*args, **kwargs)\n\n            return w  # type: ignore\n\n    if isinstance(title, str):\n        return _wrap\n    else:\n        fn = title\n        return _wrap(fn)\n\n\n@contextmanager\ndef scope_div(**attrs: Any) -> Iterator[None]:\n    \"\"\"\n    Open a <div> scope (does not change header level).\n\n    Args:\n        **attrs: HTML attributes\n\n    Example:\n        with logtree.scope_div(class_=\"grading\"):\n            logtree.log_text(\"Grade: A\")\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        yield\n        return\n\n    div = Node(\"div\", _normalize_attrs(**attrs))\n    _append(div)\n    with _in_container(div):\n        yield\n\n\n@contextmanager\ndef scope_disable() -> Iterator[None]:\n    \"\"\"\n    Disable all logging within this scope.\n\n    Example:\n        with scope_header(\"Group A\") if should_log else scope_disable():\n            logtree.log_text(\"Data here\")\n    \"\"\"\n    token = _logging_disabled.set(True)\n    try:\n        yield\n    finally:\n        _logging_disabled.reset(token)\n\n\n@contextmanager\ndef optional_enable_logging(enable: bool) -> Iterator[None]:\n    \"\"\"Context manager to optionally enable logging.\"\"\"\n    if enable:\n        yield\n    else:\n        with scope_disable():\n            yield\n\n\n@contextmanager\ndef scope_details(summary: str) -> Iterator[None]:\n    \"\"\"\n    Open a collapsible <details> scope.\n\n    Args:\n        summary: Summary text shown when collapsed\n\n    Example:\n        with logtree.scope_details(\"Click to expand\"):\n            logtree.log_text(\"Hidden content\")\n            logtree.log_text(\"More hidden content\")\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        yield\n        return\n\n    details_node = Node(\"details\", {\"class\": \"lt-details\"})\n    details_node.children.append(Node(\"summary\", {}, [html_module.escape(summary)]))\n\n    body_div = Node(\"div\", {\"class\": \"lt-details-body\"})\n    details_node.children.append(body_div)\n\n    
_append(details_node)\n    with _in_container(body_div):\n        yield\n\n\n# Public API: Content\n\n\ndef log_text(text: str, *, div_class: str | None = None) -> None:\n    \"\"\"\n    Log a text paragraph.\n\n    Args:\n        text: Text to log (will be HTML-escaped)\n        div_class: If set, wrap in <div class=\"{div_class}\"> instead of <p>\n\n    Example:\n        logtree.log_text(\"Processing complete\")\n        logtree.log_text(\"Score: 0.95\", div_class=\"score\")\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        return\n\n    escaped = html_module.escape(text)\n    if div_class:\n        _append(Node(\"div\", {\"class\": div_class}, [escaped]))\n    else:\n        _append(Node(\"p\", {\"class\": \"lt-p\"}, [escaped]))\n\n\ndef log_html(html: str, *, div_class: str | None = None) -> None:\n    \"\"\"\n    Log raw HTML (not escaped).\n\n    Args:\n        html: HTML string to insert verbatim\n        div_class: If set, wrap in <div class=\"{div_class}\">\n\n    Example:\n        logtree.log_html(\"<strong>Important</strong>\")\n        logtree.log_html(conversation_html, div_class=\"conversation\")\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        return\n\n    if div_class:\n        div = Node(\"div\", {\"class\": div_class})\n        div.children.append(html)\n        _append(div)\n    else:\n        # Create a container node that holds raw HTML\n        container = Node(\"div\", {})\n        container.children.append(html)\n        _append(container)\n\n\ndef log_formatter(formatter: Formatter) -> None:\n    \"\"\"\n    Log an object that knows how to format itself as HTML.\n\n    The formatter's CSS will be automatically included in the trace output.\n    CSS is deduplicated per trace, so logging multiple objects of the same\n    type only includes the CSS once.\n\n    Args:\n        formatter: Object implementing the Formatter protocol (to_html() and get_css())\n\n    Example:\n        from tinker_cookbook.utils.logtree_formatters import ConversationFormatter\n\n        logtree.log_formatter(ConversationFormatter(messages=[...]))\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        return\n\n    # Register CSS from the formatter (deduplicated)\n    trace = _current_trace.get()\n    assert trace is not None  # _is_logging_enabled() ensures this\n    css = formatter.get_css()\n    trace._register_formatter_css(css)\n\n    # Log the HTML, with optional structured data for JSON export\n    html = formatter.to_html()\n    container = Node(\"div\", {})\n    container.children.append(html)\n    to_data = cast(\n        \"Callable[[], dict[str, Any] | None] | None\", getattr(formatter, \"to_data\", None)\n    )\n    data = to_data() if callable(to_data) else None\n    if data is not None:\n        container.data = data\n    _append(container)\n\n\ndef details(text: str, *, summary: str = \"Details\", pre: bool = True) -> None:\n    \"\"\"\n    Log collapsible details block.\n\n    Args:\n        text: Content text\n        summary: Summary text (what you see when collapsed)\n        pre: If True, use <pre> (preserves whitespace), else <div>\n\n    Example:\n        logtree.details(long_chain_of_thought, summary=\"CoT Reasoning\", pre=True)\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        return\n\n    
details_node = Node(\"details\", {\"class\": \"lt-details\"})\n    details_node.children.append(Node(\"summary\", {}, [html_module.escape(summary)]))\n\n    body_node = Node(\"pre\" if pre else \"div\", {\"class\": \"lt-details-body\"})\n    body_node.children.append(html_module.escape(text))\n    details_node.children.append(body_node)\n\n    _append(details_node)\n\n\ndef header(text: str, *, level: int | None = None) -> None:\n    \"\"\"\n    Log an inline header.\n\n    Args:\n        text: Header text\n        level: Header level (1-6), or None to auto-compute from scope depth\n\n    Example:\n        logtree.header(\"Results\")\n        logtree.header(\"Subsection\", level=4)\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        return\n\n    h = level if level is not None else _next_header_level()\n    h = max(1, min(6, h))\n    _append(Node(f\"h{h}\", {\"class\": f\"lt-h{h}\"}, [html_module.escape(text)]))\n\n\n# Public API: Tables\n\n\ndef table(obj: Any, *, caption: str | None = None) -> None:\n    \"\"\"\n    Log a table from various data types.\n\n    Supports:\n    - pandas.DataFrame\n    - list[dict]\n    - list[list]\n\n    Does NOT support raw dict (use table_from_dict or table_from_dict_of_lists).\n\n    Args:\n        obj: Data object\n        caption: Optional caption text\n\n    Example:\n        logtree.table(df, caption=\"Results\")\n        logtree.table([{\"name\": \"Alice\", \"score\": 95}, {\"name\": \"Bob\", \"score\": 87}])\n    \"\"\"\n    # Graceful degradation: if logging is disabled, do nothing\n    if not _is_logging_enabled():\n        return\n\n    if isinstance(obj, dict):\n        raise TypeError(\n            \"table() does not accept dict directly. 
Use table_from_dict() or table_from_dict_of_lists().\"\n        )\n\n    if isinstance(obj, list) and not obj:\n        if caption:\n            _append(Node(\"div\", {\"class\": \"lt-table-caption\"}, [html_module.escape(caption)]))\n        return\n\n    # Try DataFrame — convert to records so the JSON tree gets structured\n    # Nodes (thead/tbody/tr/td) instead of a raw HTML string.\n    try:\n        import pandas as pd\n\n        if isinstance(obj, pd.DataFrame):\n            _table_from_list_of_dicts(obj.to_dict(\"records\"), caption=caption)\n            return\n    except ImportError:\n        pass\n\n    # list[dict]\n    if isinstance(obj, list) and obj and isinstance(obj[0], dict):\n        _table_from_list_of_dicts(obj, caption=caption)\n        return\n\n    # list[list]\n    if isinstance(obj, list) and obj and isinstance(obj[0], (list, tuple)):\n        _table_from_list_of_lists(obj, caption=caption)\n        return\n\n    raise TypeError(f\"table() does not support type {type(obj)}\")\n\n\ndef table_from_dict(\n    data: Mapping[Any, Any],\n    *,\n    caption: str | None = None,\n    key_header: str = \"key\",\n    value_header: str = \"value\",\n    sort_by: str | None = None,\n) -> None:\n    \"\"\"\n    Log a two-column key-value table from a dict.\n\n    Args:\n        data: Dictionary to display\n        caption: Optional caption\n        key_header: Column header for keys\n        value_header: Column header for values\n        sort_by: \"key\", \"value\", or None\n\n    Example:\n        logtree.table_from_dict({\"lr\": 0.001, \"batch_size\": 32}, caption=\"Hyperparams\")\n    \"\"\"\n    if not _is_logging_enabled():\n        return\n\n    items = list(data.items())\n    if sort_by == \"key\":\n        items.sort(key=lambda x: x[0])\n    elif sort_by == \"value\":\n        items.sort(key=lambda x: x[1])\n\n    rows = [[key_header, value_header]] + [[str(k), str(v)] for k, v in items]\n    _table_from_list_of_lists(rows, caption=caption, has_header=True)\n\n\ndef table_from_dict_of_lists(\n    columns: Mapping[str, Sequence[Any]],\n    *,\n    caption: str | None = None,\n    order: Sequence[str] | None = None,\n) -> None:\n    \"\"\"\n    Log a columnar table from dict of lists.\n\n    Args:\n        columns: Dict where keys are column names, values are column data\n        caption: Optional caption\n        order: Column order (if None, use insertion order)\n\n    Example:\n        logtree.table_from_dict_of_lists({\n            \"name\": [\"Alice\", \"Bob\"],\n            \"score\": [95, 87]\n        })\n    \"\"\"\n    if not _is_logging_enabled():\n        return\n\n    if not columns:\n        return\n\n    # Validate equal lengths\n    lengths = [len(v) for v in columns.values()]\n    if len(set(lengths)) > 1:\n        raise ValueError(\"All columns must have equal length\")\n\n    col_names = list(order) if order else list(columns.keys())\n    rows = [col_names]\n    for i in range(lengths[0]):\n        rows.append([str(columns[name][i]) for name in col_names])\n\n    _table_from_list_of_lists(rows, caption=caption, has_header=True)\n\n\ndef _table_from_list_of_dicts(data: list[dict], *, caption: str | None = None) -> None:\n    \"\"\"Helper: create table from list of dicts.\"\"\"\n    if not data:\n        return\n\n    keys = list(data[0].keys())\n    rows = [keys]\n    for item in data:\n        rows.append([str(item.get(k, \"\")) for k in keys])\n\n    _table_from_list_of_lists(rows, caption=caption, has_header=True)\n\n\ndef 
_table_from_list_of_lists(\n    rows: list[list[Any]], *, caption: str | None = None, has_header: bool = False\n) -> None:\n    \"\"\"Helper: create HTML table from list of lists.\"\"\"\n    if not rows:\n        return\n\n    if caption:\n        _append(Node(\"div\", {\"class\": \"lt-table-caption\"}, [html_module.escape(caption)]))\n\n    table_node = Node(\"table\", {\"class\": \"lt-table\"})\n\n    if has_header:\n        thead = Node(\"thead\")\n        tr = Node(\"tr\")\n        for cell in rows[0]:\n            tr.children.append(Node(\"th\", {}, [html_module.escape(str(cell))]))\n        thead.children.append(tr)\n        table_node.children.append(thead)\n        rows = rows[1:]\n\n    tbody = Node(\"tbody\")\n    for row in rows:\n        tr = Node(\"tr\")\n        for cell in row:\n            tr.children.append(Node(\"td\", {}, [html_module.escape(str(cell))]))\n        tbody.children.append(tr)\n    table_node.children.append(tbody)\n\n    _append(table_node)\n\n\n# Public API: Export & theming\n\n\ndef write_html_with_default_style(\n    body_html: str,\n    path: str | os.PathLike,\n    *,\n    title: str = \"Trace\",\n    theme: Theme | None = None,\n    lang: str = \"en\",\n    extra_head: str | None = None,\n) -> None:\n    \"\"\"\n    Write a complete HTML document with default styling.\n\n    Args:\n        body_html: Body HTML (with or without <body> tags)\n        path: Output file path\n        title: Document title\n        theme: Optional theme\n        lang: HTML lang attribute\n        extra_head: Extra content for <head>\n    \"\"\"\n    if theme is None:\n        theme = Theme()\n\n    # Create a temporary trace just for head generation\n    trace = Trace(title, None, False)\n\n    path_obj = Path(path)\n    path_obj.parent.mkdir(parents=True, exist_ok=True)\n\n    with open(path_obj, \"w\") as f:\n        f.write(f'<!doctype html>\\n<html lang=\"{html_module.escape(lang)}\">\\n')\n        f.write(\"<head>\\n\")\n        f.write(trace.head_html(theme=theme, title=title, extra_head=extra_head))\n        f.write(\"</head>\\n\")\n        # Ensure body tags are present\n        if \"<body\" not in body_html:\n            f.write(\"<body>\\n\")\n            f.write(body_html)\n            f.write(\"</body>\\n\")\n        else:\n            f.write(body_html)\n        f.write(\"</html>\\n\")\n\n\ndef write_trace_json(trace: Trace, path: str | os.PathLike) -> None:\n    \"\"\"\n    Write the trace structure to JSON.\n\n    Args:\n        trace: Trace object to serialize.\n        path: Output JSON file path.\n    \"\"\"\n    path_obj = Path(path)\n    path_obj.parent.mkdir(parents=True, exist_ok=True)\n    with open(path_obj, \"w\") as f:\n        json.dump(trace.to_dict(), f, indent=2)\n\n\ndef jinja_context(trace: Trace, **extra: Any) -> dict[str, Any]:\n    \"\"\"\n    Create a context dict for Jinja2 templates.\n\n    Args:\n        trace: Trace object\n        **extra: Additional context variables\n\n    Returns:\n        Dict with standard keys: title, generated_at, started_at, body_html, head_html\n    \"\"\"\n    return {\n        \"title\": trace.title,\n        \"generated_at\": datetime.now().isoformat(),\n        \"started_at\": trace.started_at.isoformat(),\n        \"body_html\": trace.body_html(),\n        \"head_html\": trace.head_html(),\n        **extra,\n    }\n\n\ndef render_with_jinja(\n    env: Any,\n    template_name: str,\n    *,\n    context: dict[str, Any],\n    write_to: str | os.PathLike | None = None,\n) -> str:\n    \"\"\"\n    Render 
using Jinja2 (requires jinja2 to be installed).\n\n    Args:\n        env: jinja2.Environment instance\n        template_name: Template file name\n        context: Template context\n        write_to: Optional path to write output\n\n    Returns:\n        Rendered HTML string\n    \"\"\"\n    template = env.get_template(template_name)\n    html = template.render(**context)\n\n    if write_to is not None:\n        path = Path(write_to)\n        path.parent.mkdir(parents=True, exist_ok=True)\n        with open(path, \"w\") as f:\n            f.write(html)\n\n    return html\n\n\ndef flush_trace() -> bool:\n    \"\"\"\n    Flush the current trace to the saved path even if the trace has not been exited.\n    This is useful for long-running programs where we wanna inspect some logs early.\n\n    Returns:\n        True if the trace was flushed, False otherwise.\n    \"\"\"\n    trace = _current_trace.get()\n    if trace is not None and trace.path is not None:\n        _write_trace(trace)\n        return True\n    return False\n"
  },
  {
    "path": "tinker_cookbook/utils/logtree_formatters.py",
    "content": "\"\"\"\nHTML formatters for logtree.\n\nThis module provides formatter objects that encapsulate both HTML generation\nand the CSS needed to style that HTML. Formatters implement the Formatter protocol\nfrom logtree and can be logged using `logtree.log_formatter()`.\n\"\"\"\n\nimport html\nfrom collections.abc import Sequence\nfrom dataclasses import dataclass\nfrom typing import Any\n\nfrom tinker_cookbook.renderers.base import Content, Message, message_to_jsonable\n\n\ndef _render_content_html(content: Content) -> str:\n    \"\"\"Render message content as HTML with styled parts for thinking/tool calls.\"\"\"\n    if isinstance(content, str):\n        return f'<span class=\"lt-text-part\">{html.escape(content)}</span>'\n\n    parts_html = []\n    for part in content:\n        if part[\"type\"] == \"text\":\n            parts_html.append(f'<span class=\"lt-text-part\">{html.escape(part[\"text\"])}</span>')\n        elif part[\"type\"] == \"thinking\":\n            escaped = html.escape(part[\"thinking\"])\n            parts_html.append(\n                f'<details class=\"lt-thinking-part\">'\n                f\"<summary>💭 Thinking</summary>\"\n                f\"<pre>{escaped}</pre>\"\n                f\"</details>\"\n            )\n        elif part[\"type\"] == \"tool_call\":\n            tc = part[\"tool_call\"]\n            name = html.escape(tc.function.name)\n            args = html.escape(tc.function.arguments)\n            parts_html.append(\n                f'<div class=\"lt-tool-call-part\">'\n                f'<span class=\"lt-tool-call-label\">🔧 Tool Call:</span> '\n                f\"<code>{name}({args})</code>\"\n                f\"</div>\"\n            )\n        elif part[\"type\"] == \"unparsed_tool_call\":\n            raw = html.escape(part[\"raw_text\"])\n            error = html.escape(part[\"error\"])\n            parts_html.append(\n                f'<div class=\"lt-unparsed-tool-call-part\">'\n                f'<span class=\"lt-tool-call-label\">⚠️ Unparsed Tool Call:</span> '\n                f\"<code>{raw}</code>\"\n                f'<div class=\"lt-error\">{error}</div>'\n                f\"</div>\"\n            )\n        elif part[\"type\"] == \"image\":\n            parts_html.append('<span class=\"lt-image-part\">🖼️ [Image]</span>')\n        else:\n            raise ValueError(f\"Unknown content part type: {part['type']}\")\n    return \"\\n\".join(parts_html)\n\n\n@dataclass\nclass ConversationFormatter:\n    \"\"\"\n    Formatter for conversation messages.\n\n    Renders a list of messages as a styled conversation with role-based coloring.\n    Supports structured content with thinking parts, tool calls, and text.\n    \"\"\"\n\n    messages: Sequence[Message]\n    \"\"\"List of messages, each with 'role' and 'content' keys.\"\"\"\n\n    def to_html(self) -> str:\n        \"\"\"Generate HTML for the conversation.\"\"\"\n        parts = ['<div class=\"lt-conversation\">']\n        for msg in self.messages:\n            role = html.escape(msg[\"role\"])\n            content_html = _render_content_html(msg[\"content\"])\n            parts.append(f'  <div class=\"lt-message lt-message-{role}\">')\n            parts.append(f'    <span class=\"lt-message-role\">{role}:</span>')\n            parts.append(f'    <div class=\"lt-message-content\">{content_html}</div>')\n            parts.append(\"  </div>\")\n        parts.append(\"</div>\")\n        return \"\\n\".join(parts)\n\n    def to_data(self) -> dict[str, Any]:\n        \"\"\"Return 
structured data for JSON export (avoids needing to parse raw HTML).\"\"\"\n        return {\n            \"type\": \"conversation\",\n            \"messages\": [message_to_jsonable(msg) for msg in self.messages],\n        }\n\n    def get_css(self) -> str:\n        \"\"\"Get CSS for conversation styling.\"\"\"\n        return CONVERSATION_CSS\n\n\n# CSS for conversation formatting\nCONVERSATION_CSS = \"\"\"\n/* Conversation formatting */\n.lt-conversation {\n    display: flex;\n    flex-direction: column;\n    gap: 0.35rem;\n    margin: 0.25rem 0;\n}\n\n.lt-message {\n    padding: 0.5rem 0.6rem;\n    border-radius: 6px;\n    border-left: 3px solid var(--lt-accent, #2563eb);\n    background: var(--lt-bg, #f9fafb);\n    line-height: 1.4;\n}\n\n.lt-message-role {\n    font-weight: 600;\n    color: var(--lt-accent, #2563eb);\n    display: inline-block;\n    min-width: 64px;\n}\n\n.lt-message-content {\n    white-space: pre-wrap;\n    word-wrap: break-word;\n}\n\n.lt-message-user {\n    background: #e3f2fd;\n    border-left-color: #1976d2;\n}\n\n.lt-message-user .lt-message-role {\n    color: #1565c0;\n}\n\n.lt-message-assistant {\n    background: #f3e5f5;\n    border-left-color: #7b1fa2;\n}\n\n.lt-message-assistant .lt-message-role {\n    color: #6a1b9a;\n}\n\n.lt-message-system {\n    background: #fff3e0;\n    border-left-color: #f57c00;\n}\n\n.lt-message-system .lt-message-role {\n    color: #e65100;\n}\n\n.lt-message-tool {\n    background: #e8f5e9;\n    border-left-color: #388e3c;\n}\n\n.lt-message-tool .lt-message-role {\n    color: #2e7d32;\n}\n\n/* Content part styling */\n.lt-text-part {\n    display: block;\n}\n\n.lt-thinking-part {\n    margin: 0.5rem 0;\n    padding: 0.5rem;\n    background: #fef3c7;\n    border: 1px solid #fbbf24;\n    border-radius: 4px;\n}\n\n.lt-thinking-part summary {\n    cursor: pointer;\n    font-weight: 500;\n    color: #92400e;\n}\n\n.lt-thinking-part pre {\n    margin: 0.5rem 0 0 0;\n    padding: 0.5rem;\n    background: #fffbeb;\n    border-radius: 4px;\n    font-size: 0.875rem;\n    overflow-x: auto;\n    white-space: pre-wrap;\n}\n\n.lt-tool-call-part {\n    margin: 0.5rem 0;\n    padding: 0.5rem;\n    background: #dbeafe;\n    border: 1px solid #3b82f6;\n    border-radius: 4px;\n}\n\n.lt-tool-call-label {\n    font-weight: 500;\n    color: #1e40af;\n}\n\n.lt-tool-call-part code {\n    display: block;\n    margin-top: 0.25rem;\n    padding: 0.25rem 0.5rem;\n    background: #eff6ff;\n    border-radius: 2px;\n    font-size: 0.875rem;\n    overflow-x: auto;\n}\n\n.lt-unparsed-tool-call-part {\n    margin: 0.5rem 0;\n    padding: 0.5rem;\n    background: #fee2e2;\n    border: 1px solid #ef4444;\n    border-radius: 4px;\n}\n\n.lt-unparsed-tool-call-part .lt-tool-call-label {\n    color: #991b1b;\n}\n\n.lt-unparsed-tool-call-part code {\n    display: block;\n    margin-top: 0.25rem;\n    padding: 0.25rem 0.5rem;\n    background: #fef2f2;\n    border-radius: 2px;\n    font-size: 0.875rem;\n    overflow-x: auto;\n}\n\n.lt-error {\n    margin-top: 0.25rem;\n    font-size: 0.875rem;\n    color: #dc2626;\n}\n\n.lt-image-part {\n    display: inline-block;\n    padding: 0.25rem 0.5rem;\n    background: #e0e7ff;\n    border-radius: 4px;\n    color: #3730a3;\n}\n\"\"\"\n"
  },
  {
    "path": "tinker_cookbook/utils/logtree_test.py",
    "content": "\"\"\"Tests for the logtree module.\"\"\"\n\nimport asyncio\nimport json\nimport os\nimport tempfile\nfrom pathlib import Path\nfrom typing import Any, cast\n\nfrom tinker_cookbook.renderers.base import Message, ToolCall, UnparsedToolCall\nfrom tinker_cookbook.utils import logtree\n\n\ndef test_basic_trace():\n    \"\"\"Test basic trace creation and HTML generation.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"test.html\"\n\n        with logtree.init_trace(\"Test Report\", path=output_path):\n            logtree.log_text(\"Hello world\")\n            with logtree.scope_header(\"Section 1\"):\n                logtree.log_text(\"Content in section 1\")\n\n        assert output_path.exists()\n        content = output_path.read_text()\n\n        # Check for expected elements\n        assert \"<title>Test Report</title>\" in content\n        assert \"<h1\" in content and \"Test Report\" in content\n        assert \"Hello world\" in content\n        assert \"Section 1\" in content\n        assert \"Content in section 1\" in content\n\n\ndef test_log_text_renders_inline_text_node():\n    \"\"\"Text-only paragraphs should render inline, without leading newline whitespace.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"inline_text.html\"\n\n        with logtree.init_trace(\"Inline Text Test\", path=output_path):\n            logtree.log_text(\"parse_success: 0\")\n\n        content = output_path.read_text()\n        assert '<p class=\"lt-p\">parse_success: 0</p>' in content\n\n\ndef test_nested_scopes():\n    \"\"\"Test nested scopes and auto header levels.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"nested.html\"\n\n        with logtree.init_trace(\"Nested Test\", path=output_path):\n            with logtree.scope_header(\"Level 1\"):\n                logtree.log_text(\"At level 1\")\n                with logtree.scope_header(\"Level 2\"):\n                    logtree.log_text(\"At level 2\")\n                    with logtree.scope_header(\"Level 3\"):\n                        logtree.log_text(\"At level 3\")\n\n        content = output_path.read_text()\n\n        # Check that we have h1 (title), h2, h3, h4\n        assert \"<h1\" in content\n        assert \"<h2\" in content\n        assert \"<h3\" in content\n        assert \"<h4\" in content\n\n\ndef test_conditional_logging():\n    \"\"\"Test scope_disable for conditional logging.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"conditional.html\"\n\n        with logtree.init_trace(\"Conditional Test\", path=output_path):\n            for i in range(5):\n                # Only log groups 0 and 2\n                with logtree.scope_header(f\"Group {i}\") if i in {0, 2} else logtree.scope_disable():\n                    logtree.log_text(f\"Content for group {i}\")\n\n        content = output_path.read_text()\n\n        # Check that groups 0 and 2 are present\n        assert \"Group 0\" in content\n        assert \"Content for group 0\" in content\n        assert \"Group 2\" in content\n        assert \"Content for group 2\" in content\n\n        # Check that groups 1, 3, 4 are not present\n        assert \"Group 1\" not in content\n        assert \"Group 3\" not in content\n        assert \"Group 4\" not in content\n\n\ndef test_table_rendering():\n    \"\"\"Test various table rendering functions.\"\"\"\n    with 
tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"tables.html\"\n\n        with logtree.init_trace(\"Table Test\", path=output_path):\n            # Test table_from_dict\n            logtree.table_from_dict({\"lr\": 0.001, \"batch_size\": 32}, caption=\"Hyperparams\")\n\n            # Test table from list of dicts\n            logtree.table([{\"name\": \"Alice\", \"score\": 95}, {\"name\": \"Bob\", \"score\": 87}])\n\n            # Test table_from_dict_of_lists\n            logtree.table_from_dict_of_lists(\n                {\"name\": [\"Charlie\", \"Diana\"], \"score\": [92, 88]}, caption=\"Results\"\n            )\n\n        content = output_path.read_text()\n\n        assert \"Hyperparams\" in content\n        assert \"0.001\" in content\n        assert \"batch_size\" in content\n        assert \"Alice\" in content\n        assert \"Bob\" in content\n        assert \"Charlie\" in content\n        assert \"Results\" in content\n\n\ndef test_html_content():\n    \"\"\"Test log_html for raw HTML insertion.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"html.html\"\n\n        with logtree.init_trace(\"HTML Test\", path=output_path):\n            logtree.log_html(\"<strong>Bold text</strong>\")\n            logtree.log_html(\"<em>Italic</em>\", div_class=\"emphasis\")\n\n        content = output_path.read_text()\n\n        assert \"<strong>Bold text</strong>\" in content\n        assert \"<em>Italic</em>\" in content\n        assert 'class=\"emphasis\"' in content\n\n\ndef test_details():\n    \"\"\"Test collapsible details blocks.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"details.html\"\n\n        with logtree.init_trace(\"Details Test\", path=output_path):\n            logtree.details(\"This is a long\\nmultiline\\ntext\", summary=\"Click to expand\")\n\n        content = output_path.read_text()\n\n        assert \"<details\" in content\n        assert \"<summary\" in content\n        assert \"Click to expand\" in content\n        assert \"long\" in content and \"multiline\" in content\n\n\nasync def async_test_async_safety():\n    \"\"\"Test that logtree is async-safe with concurrent tasks.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"async.html\"\n\n        async def worker(task_id: int):\n            with logtree.scope_header(f\"Task {task_id}\"):\n                logtree.log_text(f\"Started task {task_id}\")\n                await asyncio.sleep(0.01)\n                logtree.log_text(f\"Finished task {task_id}\")\n\n        with logtree.init_trace(\"Async Test\", path=output_path):\n            await asyncio.gather(*[worker(i) for i in range(5)])\n\n        content = output_path.read_text()\n\n        # Check that all tasks are logged\n        for i in range(5):\n            assert f\"Task {i}\" in content\n            assert f\"Started task {i}\" in content\n            assert f\"Finished task {i}\" in content\n\n\ndef test_async_safety():\n    \"\"\"Wrapper to run async test.\"\"\"\n    asyncio.run(async_test_async_safety())\n\n\ndef test_scope_header_decorator():\n    \"\"\"Test the scope_header_decorator.\"\"\"\n\n    @logtree.scope_header_decorator\n    def simple_function():\n        logtree.log_text(\"Inside simple function\")\n\n    @logtree.scope_header_decorator(\"Custom Title\")\n    def custom_title_function():\n        logtree.log_text(\"Inside custom title function\")\n\n    with 
tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"decorator.html\"\n\n        with logtree.init_trace(\"Decorator Test\", path=output_path):\n            simple_function()\n            custom_title_function()\n\n        content = output_path.read_text()\n\n        assert \"simple_function\" in content\n        assert \"Inside simple function\" in content\n        assert \"Custom Title\" in content\n        assert \"Inside custom title function\" in content\n\n\nasync def async_test_scope_header_decorator():\n    \"\"\"Test scope_header_decorator with async functions.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"async_decorator.html\"\n\n        @logtree.scope_header_decorator(\"Async Work\")\n        async def async_work(value: int):\n            logtree.log_text(f\"Working on {value}\")\n            await asyncio.sleep(0.01)\n            logtree.log_text(f\"Done with {value}\")\n\n        with logtree.init_trace(\"Async Decorator Test\", path=output_path):\n            await async_work(123)\n\n        content = output_path.read_text()\n\n        assert \"Async Work\" in content\n        assert \"Working on 123\" in content\n        assert \"Done with 123\" in content\n\n\ndef test_async_decorator():\n    \"\"\"Wrapper to run async decorator test.\"\"\"\n    asyncio.run(async_test_scope_header_decorator())\n\n\ndef test_error_handling():\n    \"\"\"Test that traces are written even on error when write_on_error=True.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"error.html\"\n\n        try:\n            with logtree.init_trace(\"Error Test\", path=output_path, write_on_error=True):\n                logtree.log_text(\"Before error\")\n                raise ValueError(\"Test error\")\n        except ValueError:\n            pass\n\n        assert output_path.exists()\n        content = output_path.read_text()\n\n        assert \"Before error\" in content\n        assert \"Exception\" in content\n        assert \"ValueError\" in content\n        assert \"Test error\" in content\n\n\ndef test_no_write_without_path():\n    \"\"\"Test that no file is written when path=None.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        # Change to tmpdir to ensure no files are created\n        original_cwd = os.getcwd()\n        try:\n            os.chdir(tmpdir)\n\n            with logtree.init_trace(\"No Write Test\", path=None) as trace:\n                logtree.log_text(\"This should not be written to disk\")\n                body_html = trace.body_html()\n\n            # Check that body_html contains the content\n            assert \"This should not be written to disk\" in body_html\n\n            # Check that no HTML files were created\n            html_files = list(Path(tmpdir).glob(\"*.html\"))\n            assert len(html_files) == 0\n\n        finally:\n            os.chdir(original_cwd)\n\n\ndef test_scope_div():\n    \"\"\"Test scope_div for wrapping content without changing header level.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"div.html\"\n\n        with logtree.init_trace(\"Div Test\", path=output_path):\n            with logtree.scope_header(\"Section\"):\n                with logtree.scope_div(class_=\"custom-div\"):\n                    logtree.log_text(\"Inside custom div\")\n                    logtree.header(\"Inline header\")\n\n        content = output_path.read_text()\n\n        
assert 'class=\"custom-div\"' in content\n        assert \"Inside custom div\" in content\n\n\ndef test_inline_header():\n    \"\"\"Test inline header function.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"inline_header.html\"\n\n        with logtree.init_trace(\"Inline Header Test\", path=output_path):\n            logtree.header(\"First Header\")\n            logtree.log_text(\"Some content\")\n            logtree.header(\"Second Header\", level=3)\n\n        content = output_path.read_text()\n\n        assert \"First Header\" in content\n        assert \"Second Header\" in content\n        assert \"Some content\" in content\n\n\ndef test_div_class_parameter():\n    \"\"\"Test div_class parameter for log_text and log_html.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"div_class.html\"\n\n        with logtree.init_trace(\"Div Class Test\", path=output_path):\n            logtree.log_text(\"Answer: A\", div_class=\"answer\")\n            logtree.log_text(\"Reward: 0.95\", div_class=\"reward\")\n\n        content = output_path.read_text()\n\n        assert 'class=\"answer\"' in content\n        assert 'class=\"reward\"' in content\n        assert \"Answer: A\" in content\n        assert \"Reward: 0.95\" in content\n\n\ndef test_export_helpers():\n    \"\"\"Test export helper functions.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        # Test write_html_with_default_style\n        output_path = Path(tmpdir) / \"export.html\"\n        body = \"<p>Test content</p>\"\n\n        logtree.write_html_with_default_style(body, output_path, title=\"Export Test\")\n\n        assert output_path.exists()\n        content = output_path.read_text()\n\n        assert \"<title>Export Test</title>\" in content\n        assert \"Test content\" in content\n        assert \"<!doctype html>\" in content.lower()\n\n\ndef test_write_trace_json():\n    \"\"\"Test writing trace structure to JSON.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"trace.json\"\n\n        with logtree.init_trace(\"Trace JSON Test\", path=None) as trace:\n            logtree.log_text(\"Hello JSON\")\n            with logtree.scope_header(\"Section\"):\n                logtree.log_text(\"Nested text\")\n\n        logtree.write_trace_json(trace, output_path)\n        content = json.loads(output_path.read_text())\n\n        assert content[\"title\"] == \"Trace JSON Test\"\n        assert content[\"root\"][\"tag\"] == \"body\"\n        assert \"children\" in content[\"root\"]\n        serialized = json.dumps(content)\n        assert \"Hello JSON\" in serialized\n        assert \"Section\" in serialized\n\n\ndef test_graceful_degradation():\n    \"\"\"Test that logtree functions work gracefully when no trace is active.\"\"\"\n\n    # All these should work without error when no trace is active\n    logtree.log_text(\"This should not crash\")\n    logtree.log_html(\"<p>This should not crash</p>\")\n    logtree.header(\"This should not crash\")\n    logtree.details(\"This should not crash\")\n    logtree.table([{\"a\": 1}])\n\n    with logtree.scope_header(\"This should not crash\"):\n        logtree.log_text(\"Nested content\")\n\n    with logtree.scope_div():\n        logtree.log_text(\"Div content\")\n\n    @logtree.scope_header_decorator\n    def decorated_func():\n        logtree.log_text(\"Should not crash\")\n\n    decorated_func()\n\n\nasync def 
async_test_graceful_degradation_decorator():\n    \"\"\"Test that decorated async functions work without trace.\"\"\"\n\n    @logtree.scope_header_decorator(\"Async Work\")\n    async def async_work():\n        logtree.log_text(\"Should not crash\")\n        await asyncio.sleep(0.001)\n\n    # Should work without error\n    await async_work()\n\n\ndef test_graceful_degradation_async():\n    \"\"\"Wrapper for async graceful degradation test.\"\"\"\n    asyncio.run(async_test_graceful_degradation_decorator())\n\n\ndef test_formatter():\n    \"\"\"Test the formatter object API.\"\"\"\n    from tinker_cookbook.utils.logtree_formatters import ConversationFormatter\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"formatter.html\"\n\n        messages: list[Message] = [\n            {\"role\": \"user\", \"content\": \"Hello\"},\n            {\"role\": \"assistant\", \"content\": \"Hi there!\"},\n            {\"role\": \"user\", \"content\": \"How are you?\"},\n        ]\n\n        with logtree.init_trace(\"Formatter Test\", path=output_path):\n            logtree.log_formatter(ConversationFormatter(messages=messages))\n\n        content = output_path.read_text()\n\n        # Check that messages are present\n        assert \"Hello\" in content\n        assert \"Hi there!\" in content\n        assert \"How are you?\" in content\n\n        # Check that CSS is included\n        assert \"lt-conversation\" in content\n        assert \"lt-message\" in content\n        assert \"lt-message-role\" in content\n\n\ndef test_formatter_html_escaping():\n    \"\"\"Test that ConversationFormatter properly escapes HTML in message content to prevent XSS.\"\"\"\n    from tinker_cookbook.utils.logtree_formatters import ConversationFormatter\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"xss.html\"\n\n        messages: list[Message] = [\n            {\"role\": \"user\", \"content\": \"What is <script>alert('xss')</script>?\"},\n            {\"role\": \"assistant\", \"content\": \"That's a <b>script</b> tag: <img onerror=alert(1)>\"},\n        ]\n\n        with logtree.init_trace(\"XSS Test\", path=output_path):\n            logtree.log_formatter(ConversationFormatter(messages=messages))\n\n        content = output_path.read_text()\n\n        # HTML tags should be escaped (< and > become &lt; and &gt;), not rendered\n        assert \"<script>\" not in content\n        assert \"&lt;script&gt;\" in content\n        # The <img> tag should also be escaped\n        assert \"<img onerror=\" not in content\n        assert \"&lt;img onerror=\" in content\n\n\ndef test_formatter_css_deduplication():\n    \"\"\"Test that formatter CSS is deduplicated per trace.\"\"\"\n    from tinker_cookbook.utils.logtree_formatters import ConversationFormatter\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"dedup.html\"\n\n        messages1: list[Message] = [{\"role\": \"user\", \"content\": \"Message 1\"}]\n        messages2: list[Message] = [{\"role\": \"assistant\", \"content\": \"Message 2\"}]\n        messages3: list[Message] = [{\"role\": \"user\", \"content\": \"Message 3\"}]\n\n        with logtree.init_trace(\"Dedup Test\", path=output_path):\n            # Log three conversation formatters\n            logtree.log_formatter(ConversationFormatter(messages=messages1))\n            logtree.log_formatter(ConversationFormatter(messages=messages2))\n            
logtree.log_formatter(ConversationFormatter(messages=messages3))\n\n        content = output_path.read_text()\n\n        # CSS should appear only once\n        css_count = content.count(\".lt-conversation {\")\n        assert css_count == 1, f\"Expected CSS to appear once, but appeared {css_count} times\"\n\n        # All messages should be present\n        assert \"Message 1\" in content\n        assert \"Message 2\" in content\n        assert \"Message 3\" in content\n\n\ndef test_scope_details():\n    \"\"\"Test collapsible details scope.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"scope_details.html\"\n\n        with logtree.init_trace(\"Scope Details Test\", path=output_path):\n            logtree.log_text(\"Before details\")\n            with logtree.scope_details(\"Click to expand\"):\n                logtree.log_text(\"Hidden content 1\")\n                logtree.log_text(\"Hidden content 2\")\n            logtree.log_text(\"After details\")\n\n        content = output_path.read_text()\n\n        assert \"<details\" in content\n        assert \"<summary\" in content\n        assert \"Click to expand\" in content\n        assert \"Hidden content 1\" in content\n        assert \"Hidden content 2\" in content\n        assert \"Before details\" in content\n        assert \"After details\" in content\n\n\ndef test_scope_disable_nested():\n    \"\"\"Test that scope_disable actually disables nested logging.\"\"\"\n    with tempfile.TemporaryDirectory() as tmpdir:\n        output_path = Path(tmpdir) / \"scope_disable.html\"\n\n        with logtree.init_trace(\"Scope Disable Test\", path=output_path):\n            logtree.log_text(\"Before disabled scope\")\n\n            # This entire block should not be logged\n            with logtree.scope_disable():\n                logtree.log_text(\"This should NOT appear\")\n                with logtree.scope_header(\"Nested Header\"):\n                    logtree.log_text(\"This should also NOT appear\")\n                logtree.log_text(\"Still should NOT appear\")\n\n            logtree.log_text(\"After disabled scope\")\n\n        content = output_path.read_text()\n\n        # Check that logged content is present\n        assert \"Before disabled scope\" in content\n        assert \"After disabled scope\" in content\n\n        # Check that disabled content is NOT present\n        assert \"This should NOT appear\" not in content\n        assert \"Nested Header\" not in content\n        assert \"This should also NOT appear\" not in content\n        assert \"Still should NOT appear\" not in content\n\n\ndef test_formatter_structured_data_in_json():\n    \"\"\"Test that log_formatter attaches structured data to the JSON export.\"\"\"\n    from tinker_cookbook.utils.logtree_formatters import ConversationFormatter\n\n    messages: list[Message] = [\n        {\"role\": \"user\", \"content\": \"What is 2+2?\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": [\n                {\"type\": \"thinking\", \"thinking\": \"Let me compute...\"},\n                {\"type\": \"text\", \"text\": \"The answer is 4.\"},\n            ],\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": \"Calling a tool\",\n            \"tool_calls\": [\n                ToolCall(\n                    id=\"call_123\",\n                    function=ToolCall.FunctionBody(\n                        name=\"calculator\", arguments='{\"expression\":\"2+2\"}'\n                  
  ),\n                )\n            ],\n            \"unparsed_tool_calls\": [\n                UnparsedToolCall(raw_text=\"<tool_call>{bad json}</tool_call>\", error=\"Invalid JSON\")\n            ],\n            \"tool_call_id\": \"call_123\",\n            \"name\": \"calculator\",\n            \"trainable\": False,\n        },\n    ]\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        json_path = Path(tmpdir) / \"trace.json\"\n\n        with logtree.init_trace(\"Structured Data Test\", path=None) as trace:\n            logtree.log_formatter(ConversationFormatter(messages=messages))\n\n        logtree.write_trace_json(trace, json_path)\n        content = json.loads(json_path.read_text())\n\n        # Find the node with structured data\n        def find_data_nodes(node):\n            results = []\n            if isinstance(node, dict):\n                if \"data\" in node:\n                    results.append(node[\"data\"])\n                for child in node.get(\"children\", []):\n                    if isinstance(child, dict):\n                        results.extend(find_data_nodes(child))\n            return results\n\n        data_nodes = find_data_nodes(content[\"root\"])\n        assert len(data_nodes) == 1, f\"Expected 1 data node, got {len(data_nodes)}\"\n\n        data = data_nodes[0]\n        assert data[\"type\"] == \"conversation\"\n        assert len(data[\"messages\"]) == 3\n        assert data[\"messages\"][0][\"role\"] == \"user\"\n        assert data[\"messages\"][0][\"content\"] == \"What is 2+2?\"\n        assert data[\"messages\"][1][\"role\"] == \"assistant\"\n        assert data[\"messages\"][1][\"content\"][0][\"type\"] == \"thinking\"\n        assert data[\"messages\"][1][\"content\"][0][\"thinking\"] == \"Let me compute...\"\n        assert data[\"messages\"][1][\"content\"][1][\"type\"] == \"text\"\n        assert data[\"messages\"][1][\"content\"][1][\"text\"] == \"The answer is 4.\"\n        assert data[\"messages\"][2][\"tool_calls\"] == [\n            {\n                \"type\": \"function\",\n                \"id\": \"call_123\",\n                \"function\": {\"name\": \"calculator\", \"arguments\": '{\"expression\":\"2+2\"}'},\n            }\n        ]\n        assert data[\"messages\"][2][\"unparsed_tool_calls\"] == [\n            {\"raw_text\": \"<tool_call>{bad json}</tool_call>\", \"error\": \"Invalid JSON\"}\n        ]\n        assert data[\"messages\"][2][\"tool_call_id\"] == \"call_123\"\n        assert data[\"messages\"][2][\"name\"] == \"calculator\"\n        assert data[\"messages\"][2][\"trainable\"] is False\n\n        # Nodes with data should NOT have raw HTML string children in JSON\n        def find_nodes_with_data(node):\n            results = []\n            if isinstance(node, dict):\n                if \"data\" in node:\n                    results.append(node)\n                for child in node.get(\"children\", []):\n                    if isinstance(child, dict):\n                        results.extend(find_nodes_with_data(child))\n            return results\n\n        for node in find_nodes_with_data(content[\"root\"]):\n            for child in node.get(\"children\", []):\n                assert not isinstance(child, str), (\n                    f\"Node with data should not have string children in JSON, got: {child[:80]}\"\n                )\n\n\ndef test_log_formatter_without_to_data_still_works():\n    \"\"\"Custom formatters that predate to_data should still log successfully.\"\"\"\n\n    class LegacyFormatter:\n     
   def to_html(self) -> str:\n            return \"<div>legacy</div>\"\n\n        def get_css(self) -> str:\n            return \"\"\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        json_path = Path(tmpdir) / \"trace.json\"\n\n        with logtree.init_trace(\"Legacy Formatter Test\", path=None) as trace:\n            logtree.log_formatter(cast(Any, LegacyFormatter()))\n\n        logtree.write_trace_json(trace, json_path)\n        content = json.loads(json_path.read_text())\n        serialized = json.dumps(content)\n        assert \"legacy\" in serialized\n\n\ndef test_dataframe_table_produces_structured_nodes():\n    \"\"\"Test that DataFrame tables produce structured Nodes, not raw HTML strings.\"\"\"\n    import pandas as pd\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        json_path = Path(tmpdir) / \"trace.json\"\n\n        df = pd.DataFrame({\"name\": [\"Alice\", \"Bob\"], \"score\": [95, 87]})\n\n        with logtree.init_trace(\"DataFrame Test\", path=None) as trace:\n            logtree.table(df, caption=\"Results\")\n\n        logtree.write_trace_json(trace, json_path)\n        content = json.loads(json_path.read_text())\n\n        # Walk the tree: every child should be either a dict (Node) or a plain\n        # text string that does NOT contain HTML tags.  Raw HTML from df.to_html()\n        # would include \"<table\" or \"<tr\".\n        serialized = json.dumps(content)\n        assert \"Alice\" in serialized\n        assert \"Bob\" in serialized\n\n        def check_no_raw_html_tables(node):\n            if isinstance(node, str):\n                assert \"<table\" not in node, f\"Found raw HTML table string: {node[:100]}\"\n                assert \"<tr\" not in node, f\"Found raw HTML tr string: {node[:100]}\"\n            elif isinstance(node, dict):\n                for child in node.get(\"children\", []):\n                    check_no_raw_html_tables(child)\n\n        check_no_raw_html_tables(content[\"root\"])\n\n\nif __name__ == \"__main__\":\n    # Run tests\n    test_basic_trace()\n    test_nested_scopes()\n    test_conditional_logging()\n    test_table_rendering()\n    test_html_content()\n    test_details()\n    test_async_safety()\n    test_scope_header_decorator()\n    test_async_decorator()\n    test_error_handling()\n    test_no_write_without_path()\n    test_scope_div()\n    test_inline_header()\n    test_div_class_parameter()\n    test_export_helpers()\n    test_graceful_degradation()\n    test_graceful_degradation_async()\n    test_formatter()\n    test_formatter_css_deduplication()\n    test_scope_details()\n    test_scope_disable_nested()\n    test_formatter_structured_data_in_json()\n    test_dataframe_table_produces_structured_nodes()\n\n    print(\"All tests passed!\")\n"
  },
  {
    "path": "tinker_cookbook/utils/lr_scheduling.py",
    "content": "import logging\nimport math\nfrom typing import Literal\n\nfrom tinker_cookbook.exceptions import ConfigurationError\n\nlogger = logging.getLogger(__name__)\n\n\nLRSchedule = Literal[\"linear\", \"cosine\", \"constant\"]\n\n\ndef compute_schedule_lr_multiplier(lr_schedule: LRSchedule, step: int, total_steps: int) -> float:\n    \"\"\"\n    What factor to multiply the base LR by due to the LR schedule\n    \"\"\"\n    if lr_schedule == \"linear\":\n        return 1 - step / total_steps\n    elif lr_schedule == \"cosine\":\n        return 0.5 * (1 + math.cos(math.pi * step / total_steps))\n    elif lr_schedule == \"constant\":\n        return 1\n    else:\n        raise ConfigurationError(f\"Unknown learning rate schedule: {lr_schedule}\")\n"
  },
  {
    "path": "tinker_cookbook/utils/misc_utils.py",
    "content": "\"\"\"\nSmall utilities requiring only basic python libraries.\n\"\"\"\n\nimport importlib\nimport logging\nimport time\nfrom collections.abc import Sequence\nfrom contextlib import contextmanager\nfrom typing import Any, TypeVar, cast\n\nimport numpy as np\n\nlogger = logging.getLogger(__name__)\n\nT = TypeVar(\"T\")\n\n\n@contextmanager\ndef timed(key: str, metrics: dict[str, Any]):\n    logger.info(f\"Starting {key}\")\n    tstart = time.time()\n    yield\n    logger.info(f\"{key} took {time.time() - tstart:.2f} seconds\")\n    metrics[f\"time/{key}\"] = time.time() - tstart\n\n\nsafezip = cast(type[zip], lambda *args, **kwargs: zip(*args, **kwargs, strict=True))\n\n\ndef dict_mean(list_of_dicts: list[dict[str, float | int]]) -> dict[str, float]:\n    key2values = {}\n    for d in list_of_dicts:\n        for k, v in d.items():\n            key2values.setdefault(k, []).append(v)\n    return {k: float(np.mean(values)) for k, values in key2values.items()}\n\n\ndef all_same(xs: list[Any]) -> bool:\n    return all(x == xs[0] for x in xs)\n\n\ndef lookup_func(path_to_func: str, default_module: str | None = None):\n    \"\"\"\n    path.to.module:func_name or func_name (assumes default_module)\n    \"\"\"\n    colon_count = path_to_func.count(\":\")\n    if colon_count == 0 and default_module is not None:\n        module_name = default_module\n        func_name = path_to_func\n    elif colon_count == 1:\n        module_name, func_name = path_to_func.rsplit(\":\", 1)\n    else:\n        raise ValueError(f\"Invalid path: {path_to_func}\")\n    module = importlib.import_module(module_name)\n    return getattr(module, func_name)\n\n\ndef split_list(lst: Sequence[T], num_splits: int) -> list[list[T]]:\n    \"\"\"\n    Split a sequence into a list of lists, where the sizes are as equal as possible,\n    and the long and short lists are as uniformly distributed as possible.\n\n    Args:\n        lst: The sequence to split\n        num_splits: Number of sublists to create\n\n    Returns:\n        A list of sublists with sizes differing by at most 1\n\n    Raises:\n        ValueError: If num_splits > len(lst) or num_splits <= 0\n\n    Examples:\n        >>> split_list([1, 2, 3, 4, 5], 2)\n        [[1, 2, 3], [4, 5]]\n        >>> split_list([1, 2, 3, 4, 5], 3)\n        [[1, 2], [3, 4], [5]]\n    \"\"\"\n    if num_splits <= 0:\n        raise ValueError(f\"num_splits must be positive, got {num_splits}\")\n    if num_splits > len(lst):\n        raise ValueError(f\"Cannot split list of length {len(lst)} into {num_splits} parts\")\n\n    edges = np.linspace(0, len(lst), num_splits + 1).astype(int)\n    return [list(lst[edges[i] : edges[i + 1]]) for i in range(num_splits)]\n\n\ndef concat_lists(list_of_lists: list[list[Any]]) -> list[Any]:\n    return [item for sublist in list_of_lists for item in sublist]\n\n\ndef not_none(x: T | None) -> T:\n    assert x is not None, f\"{x=} must not be None\"\n    return x\n"
  },
  {
    "path": "tinker_cookbook/utils/ml_log.py",
    "content": "\"\"\"Simplified logging utilities for tinker-cookbook.\"\"\"\n\nimport json\nimport logging\nimport os\nimport shlex\nimport sys\nfrom abc import ABC, abstractmethod\nfrom contextlib import contextmanager\nfrom dataclasses import asdict, is_dataclass\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import Any\n\nimport chz\nfrom rich.console import Console\nfrom rich.table import Table\n\nfrom tinker_cookbook.exceptions import ConfigurationError\nfrom tinker_cookbook.utils.code_state import code_state\n\nlogger = logging.getLogger(__name__)\n\n# Check WandB availability\n_wandb_available = False\ntry:\n    import wandb\n\n    _wandb_available = True\nexcept ImportError:\n    wandb = None\n\n# Check Neptune availability\n_neptune_available = False\ntry:\n    from neptune_scale import Run as NeptuneRun\n\n    _neptune_available = True\nexcept ImportError:\n    NeptuneRun = None\n\n# Check Trackio availability\n_trackio_available = False\ntry:\n    import trackio\n\n    _trackio_available = True\nexcept ImportError:\n    trackio = None\n\n\ndef dump_config(config: Any) -> Any:\n    \"\"\"Convert configuration object to JSON-serializable format.\"\"\"\n    if hasattr(config, \"to_dict\"):\n        return config.to_dict()\n    elif chz.is_chz(config):\n        # Recursively dump values to handle nested non-serializable fields\n        return {k: dump_config(v) for k, v in chz.asdict(config).items()}\n    elif is_dataclass(config) and not isinstance(config, type):\n        # Recursively dump values to handle nested non-serializable fields\n        return {k: dump_config(v) for k, v in asdict(config).items()}\n    elif isinstance(config, dict):\n        return {k: dump_config(v) for k, v in config.items()}\n    elif isinstance(config, (list, tuple)):\n        return [dump_config(item) for item in config]\n    elif isinstance(config, Enum):\n        return config.value\n    elif hasattr(config, \"__dict__\"):\n        # Handle simple objects with __dict__\n        return {\n            k: dump_config(v) for k, v in config.__dict__.items() if not k.startswith((\"_\", \"X_\"))\n        }\n    elif callable(config):\n        # For callables, return their string representation\n        return f\"{config.__module__}.{config.__name__}\"\n    else:\n        return config\n\n\nclass Logger(ABC):\n    \"\"\"Abstract base class for loggers.\"\"\"\n\n    @abstractmethod\n    def log_hparams(self, config: Any) -> None:\n        \"\"\"Log hyperparameters/configuration.\"\"\"\n        pass\n\n    @abstractmethod\n    def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:\n        \"\"\"Log metrics dictionary with optional step number.\"\"\"\n        pass\n\n    def log_long_text(self, key: str, text: str) -> None:\n        \"\"\"Log long text content (optional to implement).\"\"\"\n        pass\n\n    def close(self) -> None:\n        \"\"\"Cleanup when done (optional to implement).\"\"\"\n        pass\n\n    def sync(self) -> None:\n        \"\"\"Force synchronization (optional to implement).\"\"\"\n        pass\n\n    def get_logger_url(self) -> str | None:\n        \"\"\"Get a permalink to view this logger's results.\"\"\"\n        return None\n\n\nclass _PermissiveJSONEncoder(json.JSONEncoder):\n    \"\"\"A JSON encoder that handles non-encodable objects by converting them to their type string.\"\"\"\n\n    def default(self, o: Any) -> Any:\n        try:\n            return super().default(o)\n        except TypeError:\n            # Only handle the 
truly non-encodable objects\n            return str(type(o))\n\n\nclass JsonLogger(Logger):\n    \"\"\"Logger that writes metrics to a JSONL file.\"\"\"\n\n    def __init__(self, log_dir: str | Path):\n        self.log_dir = Path(log_dir).expanduser()\n        self.log_dir.mkdir(parents=True, exist_ok=True)\n        self.metrics_file = self.log_dir / \"metrics.jsonl\"\n        self._logged_hparams = False\n\n    def log_hparams(self, config: Any) -> None:\n        \"\"\"Log hyperparameters to a separate config.json file.\"\"\"\n        if not self._logged_hparams:\n            config_dict = dump_config(config)\n            config_file = self.log_dir / \"config.json\"\n            with open(config_file, \"w\") as f:\n                json.dump(config_dict, f, indent=2, cls=_PermissiveJSONEncoder)\n            diff_file = code_state()\n            with open(self.log_dir / \"code.diff\", \"w\") as f:\n                f.write(diff_file)\n            self._logged_hparams = True\n\n    def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:\n        \"\"\"Append metrics to JSONL file.\"\"\"\n        log_entry = {\"step\": step} if step is not None else {}\n        log_entry.update(metrics)\n\n        with open(self.metrics_file, \"a\") as f:\n            f.write(json.dumps(log_entry) + \"\\n\")\n            logger.info(\"Wrote metrics to %s\", self.metrics_file)\n\n\nclass PrettyPrintLogger(Logger):\n    \"\"\"Logger that displays metrics in a formatted table in the console.\"\"\"\n\n    def __init__(self):\n        self.console = Console()\n        self._last_step = None\n\n    def log_hparams(self, config: Any) -> None:\n        \"\"\"Print configuration summary.\"\"\"\n        config_dict = dump_config(config)\n        with _rich_console_use_logger(self.console):\n            self.console.print(\"[bold cyan]Configuration:[/bold cyan]\")\n            for key, value in config_dict.items():\n                self.console.print(f\"  {key}: {_maybe_truncate_repr(value)}\")\n\n    def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:\n        \"\"\"Display metrics in console.\"\"\"\n        if not metrics:\n            return\n\n        table = Table(show_header=True, header_style=\"bold magenta\")\n        table.add_column(\"Metric\", style=\"cyan\", width=30)\n        table.add_column(\"Value\", style=\"green\")\n\n        if step is not None:\n            table.title = f\"Step {step}\"\n\n        for key, value in sorted(metrics.items()):\n            if isinstance(value, float):\n                value_str = f\"{value:.6f}\"\n            else:\n                value_str = str(value)\n            table.add_row(key, value_str)\n\n        with _rich_console_use_logger(self.console):\n            self.console.print(table)\n\n\ndef _maybe_truncate_repr(value: Any) -> str:\n    repr_value = repr(value)\n    if len(repr_value) > 256:\n        return repr_value[:128] + \" ... 
\" + repr_value[-128:]\n    return repr_value\n\n\n@contextmanager\ndef _rich_console_use_logger(console: Console):\n    with console.capture() as capture:\n        yield\n    logger.info(\"\\n\" + capture.get().rstrip())\n    # ^^^ add a leading newline so things like table formatting work properly\n\n\nclass WandbLogger(Logger):\n    \"\"\"Logger for Weights & Biases.\"\"\"\n\n    def __init__(\n        self,\n        project: str | None = None,\n        config: Any | None = None,\n        log_dir: str | Path | None = None,\n        wandb_name: str | None = None,\n    ):\n        if not _wandb_available:\n            raise ImportError(\n                \"wandb is not installed. Please install it with: \"\n                \"pip install wandb (or uv add wandb)\"\n            )\n\n        if not os.environ.get(\"WANDB_API_KEY\"):\n            raise ConfigurationError(\"WANDB_API_KEY environment variable not set\")\n\n        # Initialize wandb run\n        assert wandb is not None  # For type checker\n        self.run = wandb.init(\n            project=project,\n            config=dump_config(config) if config else None,\n            dir=str(log_dir) if log_dir else None,\n            name=wandb_name,\n        )\n\n    def log_hparams(self, config: Any) -> None:\n        \"\"\"Log hyperparameters to wandb.\"\"\"\n        if self.run and wandb is not None:\n            wandb.config.update(dump_config(config))\n\n    def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:\n        \"\"\"Log metrics to wandb.\"\"\"\n        if self.run and wandb is not None:\n            wandb.log(metrics, step=step)\n            logger.info(\"Logging to: %s\", self.run.url)\n\n    def close(self) -> None:\n        \"\"\"Close wandb run.\"\"\"\n        if self.run and wandb is not None:\n            wandb.finish()\n\n    def get_logger_url(self) -> str | None:\n        \"\"\"Get the URL of the wandb run.\"\"\"\n        if self.run and wandb is not None:\n            return self.run.url\n        return None\n\n\nclass NeptuneLogger(Logger):\n    \"\"\"Logger for Neptune.\"\"\"\n\n    def __init__(\n        self,\n        project: str | None = None,\n        config: Any | None = None,\n        log_dir: str | Path | None = None,\n        neptune_name: str | None = None,\n    ):\n        if not _neptune_available:\n            raise ImportError(\n                \"neptune-scale is not installed. 
Please install it with: \"\n                \"pip install neptune-scale (or uv add neptune-scale)\"\n            )\n\n        if not os.environ.get(\"NEPTUNE_API_TOKEN\"):\n            raise ConfigurationError(\"NEPTUNE_API_TOKEN environment variable not set\")\n\n        # Initialize neptune run\n        assert NeptuneRun is not None  # For type checker\n        self.run = NeptuneRun(\n            project=project,\n            log_directory=str(log_dir) if log_dir else None,\n            experiment_name=neptune_name,\n        )\n        self.run.log_configs(dump_config(config) if config else None, flatten=True)\n\n    def log_hparams(self, config: Any) -> None:\n        \"\"\"Log hyperparameters to neptune.\"\"\"\n        if self.run and NeptuneRun is not None:\n            self.run.log_configs(dump_config(config) if config else None, flatten=True)\n\n    def log_metrics(\n        self,\n        metrics: dict[str, Any],\n        step: float | int | None = None,\n    ) -> None:\n        \"\"\"Log metrics to neptune.\"\"\"\n        if self.run and NeptuneRun is not None:\n            assert step is not None, \"step is required to be int or float for Neptune logging.\"\n            self.run.log_metrics(metrics, step=step)\n            logger.info(\"Logging to: %s\", self.run.get_run_url())\n\n    def close(self) -> None:\n        \"\"\"Close neptune run.\"\"\"\n        if self.run and NeptuneRun is not None:\n            self.run.close()\n\n\nclass TrackioLogger(Logger):\n    \"\"\"Logger for Trackio.\"\"\"\n\n    def __init__(\n        self,\n        project: str | None = None,\n        config: Any | None = None,\n        log_dir: str | Path | None = None,\n        trackio_name: str | None = None,\n    ):\n        if not _trackio_available:\n            raise ImportError(\n                \"trackio is not installed. 
Please install it with: \"\n                \"pip install trackio (or uv add trackio)\"\n            )\n\n        assert trackio is not None\n        self.run = trackio.init(\n            project=project or \"default\",\n            name=trackio_name,\n            config=dump_config(config) if config else None,\n        )\n\n    def log_hparams(self, config: Any) -> None:\n        \"\"\"Log hyperparameters to trackio.\"\"\"\n        if self.run and trackio is not None:\n            pass\n\n    def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:\n        \"\"\"Log metrics to trackio.\"\"\"\n        if self.run and trackio is not None:\n            trackio.log(metrics, step=step)\n            logger.info(\"Logged metrics to Trackio project: %s\", self.run.project)\n\n    def close(self) -> None:\n        \"\"\"Close trackio run.\"\"\"\n        if self.run and trackio is not None:\n            trackio.finish()\n\n\nclass MultiplexLogger(Logger):\n    \"\"\"Logger that forwards operations to multiple child loggers.\"\"\"\n\n    def __init__(self, loggers: list[Logger]):\n        self.loggers = loggers\n\n    def log_hparams(self, config: Any) -> None:\n        \"\"\"Forward log_hparams to all child loggers.\"\"\"\n        for logger in self.loggers:\n            logger.log_hparams(config)\n\n    def log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None:\n        \"\"\"Forward log_metrics to all child loggers.\"\"\"\n        for logger in self.loggers:\n            logger.log_metrics(metrics, step)\n\n    def log_long_text(self, key: str, text: str) -> None:\n        \"\"\"Forward log_long_text to all child loggers.\"\"\"\n        for logger in self.loggers:\n            if hasattr(logger, \"log_long_text\"):\n                logger.log_long_text(key, text)\n\n    def close(self) -> None:\n        \"\"\"Close all child loggers.\"\"\"\n        for logger in self.loggers:\n            if hasattr(logger, \"close\"):\n                logger.close()\n\n    def sync(self) -> None:\n        \"\"\"Sync all child loggers.\"\"\"\n        for logger in self.loggers:\n            if hasattr(logger, \"sync\"):\n                logger.sync()\n\n    def get_logger_url(self) -> str | None:\n        \"\"\"Get the first URL returned by the child loggers.\"\"\"\n        for logger in self.loggers:\n            if url := logger.get_logger_url():\n                return url\n        return None\n\n\ndef setup_logging(\n    log_dir: str,\n    wandb_project: str | None = None,\n    wandb_name: str | None = None,\n    config: Any | None = None,\n    do_configure_logging_module: bool = True,\n) -> Logger:\n    \"\"\"\n    Set up logging infrastructure with multiple backends.\n\n    Args:\n        log_dir: Directory for logs\n        wandb_project: W&B project name (if None, W&B logging is skipped)\n        wandb_name: W&B run name\n        config: Configuration object to log\n        do_configure_logging_module: Whether to configure the logging module\n\n    Returns:\n        MultiplexLogger that combines all enabled loggers\n    \"\"\"\n    # Create log directory\n    log_dir_path = Path(log_dir).expanduser()\n    log_dir_path.mkdir(parents=True, exist_ok=True)\n\n    # Initialize loggers\n    loggers = []\n\n    # Always add JSON logger\n    loggers.append(JsonLogger(log_dir_path))\n\n    # Always add pretty print logger\n    loggers.append(PrettyPrintLogger())\n\n    # Add W&B logger if available and configured\n    if wandb_project:\n        if not 
_wandb_available:\n            print(\"WARNING: wandb is not installed. Skipping W&B logging.\")\n        elif not os.environ.get(\"WANDB_API_KEY\"):\n            print(\"WARNING: WANDB_API_KEY environment variable not set. Skipping W&B logging. \")\n        else:\n            loggers.append(\n                WandbLogger(\n                    project=wandb_project,\n                    config=config,\n                    log_dir=log_dir_path,\n                    wandb_name=wandb_name,\n                )\n            )\n\n    # Add Neptune logger if available and configured\n    # - MZ 10/8/25: Hack, but before doing bigger logger-agnostic refactor,\n    #   allow Neptune to use the same W&B project and name.\n    # - Project_name should be `workspace-name/project-name`.\n    # - Also allow logging to both W&B and Neptune\n    if wandb_project and _neptune_available:\n        # if not _neptune_available:\n        #     print(\"WARNING: neptune-scale is not installed. Skipping Neptune logging.\")\n        if not os.environ.get(\"NEPTUNE_API_TOKEN\"):\n            print(\n                \"WARNING: NEPTUNE_API_TOKEN environment variable not set. \"\n                \"Skipping Neptune logging. \"\n            )\n        else:\n            loggers.append(\n                NeptuneLogger(\n                    project=wandb_project,\n                    config=config,\n                    log_dir=log_dir_path,\n                    neptune_name=wandb_name,\n                )\n            )\n\n    if wandb_project and _trackio_available:\n        loggers.append(\n            TrackioLogger(\n                project=wandb_project,\n                config=config,\n                log_dir=log_dir_path,\n                trackio_name=wandb_name,\n            )\n        )\n        print(f\"Trackio logging enabled for project: {wandb_project}\")\n\n    # Create multiplex logger\n    ml_logger = MultiplexLogger(loggers)\n\n    # Log initial configuration\n    if config is not None:\n        ml_logger.log_hparams(config)\n\n    if do_configure_logging_module:\n        configure_logging_module(str(log_dir_path / \"logs.log\"))\n\n    logger.info(f\"Logging to: {log_dir_path}\")\n    return ml_logger\n\n\ndef _get_command_line_invocation() -> str:\n    \"\"\"Return the current command line in a shell-safe form.\"\"\"\n    if not sys.argv:\n        return \"<empty sys.argv>\"\n    return shlex.join(sys.argv)\n\n\ndef configure_logging_module(path: str, level: int = logging.INFO) -> logging.Logger:\n    \"\"\"Configure logging to console (color) and file (plain), forcing override of prior config.\"\"\"\n    # ANSI escape codes for colors\n    COLORS = {\n        \"DEBUG\": \"\\033[94m\",  # Blue\n        \"INFO\": \"\\033[92m\",  # Green\n        \"WARNING\": \"\\033[93m\",  # Yellow\n        \"ERROR\": \"\\033[91m\",  # Red\n        \"CRITICAL\": \"\\033[95m\",  # Magenta\n    }\n    RESET = \"\\033[0m\"\n\n    class ColorFormatter(logging.Formatter):\n        \"\"\"Colorized log formatter for console output that doesn't mutate record.levelname.\"\"\"\n\n        def format(self, record: logging.LogRecord) -> str:\n            color = COLORS.get(record.levelname, \"\")\n            # add a separate attribute for the colored level name\n            record.levelname_colored = f\"{color}{record.levelname}{RESET}\"\n            return super().format(record)\n\n    # Console handler with colors\n    console_handler = logging.StreamHandler()\n    console_handler.setFormatter(\n        
ColorFormatter(\"%(name)s:%(lineno)d [%(levelname_colored)s] %(message)s\")\n    )\n\n    # File handler without colors\n    file_handler = logging.FileHandler(path, mode=\"a\", encoding=\"utf-8\")\n    file_handler.setFormatter(logging.Formatter(\"%(name)s:%(lineno)d [%(levelname)s] %(message)s\"))\n\n    # Force override like basicConfig(..., force=True)\n    root = logging.getLogger()\n    root.setLevel(level)\n    for handler in root.handlers[:]:\n        root.removeHandler(handler)\n        handler.close()\n    root.addHandler(console_handler)\n    root.addHandler(file_handler)\n    root.info(\"Command line invocation: %s\", _get_command_line_invocation())\n\n    return root\n"
  },
  {
    "path": "tinker_cookbook/utils/ml_log_test.py",
    "content": "import logging\nimport shlex\nimport sys\nfrom unittest.mock import patch\n\nfrom .ml_log import configure_logging_module\n\n\ndef _flush_root_handlers() -> None:\n    for handler in logging.getLogger().handlers:\n        handler.flush()\n\n\ndef test_configure_logging_module_logs_invocation_and_appends(tmp_path):\n    log_path = tmp_path / \"logs.log\"\n\n    argv_first = [\"python\", \"train.py\", \"--log-path\", str(tmp_path), \"--run-name\", \"first run\"]\n    with patch.object(sys, \"argv\", argv_first):\n        root_logger = configure_logging_module(str(log_path))\n        root_logger.info(\"first message\")\n        _flush_root_handlers()\n\n    first_contents = log_path.read_text()\n    first_invocation = shlex.join(argv_first)\n    assert f\"Command line invocation: {first_invocation}\" in first_contents\n    assert \"first message\" in first_contents\n    assert first_contents.index(first_invocation) < first_contents.index(\"first message\")\n\n    argv_second = [\"python\", \"train.py\", \"--resume\", \"--run-name\", \"second run\"]\n    with patch.object(sys, \"argv\", argv_second):\n        root_logger = configure_logging_module(str(log_path))\n        root_logger.info(\"second message\")\n        _flush_root_handlers()\n\n    final_contents = log_path.read_text()\n    second_invocation = shlex.join(argv_second)\n    assert \"first message\" in final_contents\n    assert \"second message\" in final_contents\n    assert f\"Command line invocation: {second_invocation}\" in final_contents\n    assert final_contents.count(\"Command line invocation:\") == 2\n    assert final_contents.index(\"first message\") < final_contents.index(second_invocation)\n    assert final_contents.index(second_invocation) < final_contents.index(\"second message\")\n"
  },
  {
    "path": "tinker_cookbook/utils/trace.py",
    "content": "import argparse\nimport asyncio\nimport atexit\nimport contextlib\nimport datetime\nimport functools\nimport inspect\nimport json\nimport logging\nimport queue\nimport threading\nimport time\nfrom collections import defaultdict\nfrom collections.abc import Callable, Generator\nfrom contextvars import ContextVar\nfrom dataclasses import dataclass, field\nfrom enum import StrEnum\nfrom io import TextIOWrapper\nfrom pathlib import Path\nfrom typing import Any\n\nlogger = logging.getLogger(__name__)\n\n\nclass EventType(StrEnum):\n    \"\"\"Chrome Trace/Perfetto Event type\"\"\"\n\n    BEGIN = \"B\"\n    END = \"E\"\n    METADATA = \"M\"\n\n\n@dataclass\nclass TraceEvent:\n    \"\"\"Represents a trace event in Chrome Trace/Perfetto Format\"\"\"\n\n    name: str\n    ph: EventType\n    pid: int\n    tid: int\n    ts: float\n    args: dict[str, Any] = field(default_factory=dict)\n    cat: str | None = None\n\n    def to_dict(self) -> dict[str, Any]:\n        \"\"\"Convert the TraceEvent to a dictionary for JSON serialization.\"\"\"\n        result = {\n            \"name\": self.name,\n            \"ph\": self.ph.value,\n            \"pid\": self.pid,\n            \"tid\": self.tid,\n            \"ts\": self.ts,\n            \"args\": self.args,\n        }\n        if self.cat is not None:\n            result[\"cat\"] = self.cat\n        return result\n\n\n@dataclass\nclass ScopeContext:\n    # Additional attributes to log into the trace for this function call\n    attributes: dict[str, Any] = field(default_factory=dict)\n\n\n# Context variable to track the current coroutine's trace context\ntrace_context: ContextVar[ScopeContext | None] = ContextVar(\"trace_context\", default=None)\n\n\n@dataclass\nclass SpanRecord:\n    \"\"\"A recorded span within an iteration window.\n\n    We store two sets of timestamps:\n    - ``start_time`` / ``end_time``: from ``time.perf_counter()``, used for duration\n      calculations (aggregation metrics). High resolution but process-local — values\n      cannot be compared across processes.\n    - ``wall_start`` / ``wall_end``: from ``time.time()``, used for positioning spans\n      on Gantt charts. Synchronized across processes on the same machine, so spans\n      from multiprocess workers (ProcessPoolExecutor, Ray) can be placed on a shared\n      timeline without clock alignment.\n    \"\"\"\n\n    name: str\n    start_time: float  # seconds (perf_counter, process-local)\n    end_time: float  # seconds (perf_counter, process-local)\n    wall_start: float  # seconds since epoch (time.time, cross-process comparable)\n    wall_end: float  # seconds since epoch (time.time, cross-process comparable)\n\n\nclass IterationWindow:\n    \"\"\"Collects span records during a single training iteration for aggregation.\n\n    Use with :func:`trace_iteration` to automatically capture all ``@scope`` and\n    ``scope_span`` timings within a training iteration. 
After the block exits,\n    call :meth:`get_timing_metrics` for a flat dict of timing metrics ready to log.\n\n    Example — GRPO training loop::\n\n        for i_batch in range(n_batches):\n            with trace_iteration(step=i_batch) as window:\n                # All @scope-decorated calls inside this block are recorded\n                await run_evals(sampling_client, ...)\n                trajectory_groups = await gather_rollouts(sampling_client, ...)\n                await train_step(training_client, trajectory_groups, ...)\n\n            # Aggregated metrics: time/total, time/run_evals, time/sample_async:total, ...\n            metrics.update(window.get_timing_metrics())\n\n            # Persist per-span data for post-hoc analysis\n            window.write_spans_jsonl(log_path / \"timing_spans.jsonl\", step=i_batch)\n\n            # Optional: save a Gantt chart every K steps\n            if i_batch % 10 == 0:\n                save_gantt_chart_html(window, i_batch, log_path / f\"gantt_{i_batch}.html\")\n\n            ml_logger.log_metrics(metrics, step=i_batch)\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.spans: list[SpanRecord] = []\n        self._lock = threading.Lock()\n        self._total_time: float | None = None\n\n    def record_span(self, name: str, start_time: float, end_time: float) -> None:\n        with self._lock:\n            self.spans.append(\n                SpanRecord(\n                    name=name,\n                    start_time=start_time,\n                    end_time=end_time,\n                    wall_start=time.time() - (time.perf_counter() - start_time),\n                    wall_end=time.time() - (time.perf_counter() - end_time),\n                )\n            )\n\n    def aggregate(self) -> dict[str, float]:\n        \"\"\"Aggregate collected spans into a flat timing dict.\"\"\"\n        with self._lock:\n            spans = list(self.spans)\n\n        if not spans:\n            return {}\n\n        # Group durations by name\n        durations_by_name: dict[str, list[float]] = defaultdict(list)\n        for span in spans:\n            durations_by_name[span.name].append(span.end_time - span.start_time)\n\n        metrics: dict[str, float] = {}\n        for name, durations in durations_by_name.items():\n            if len(durations) == 1:\n                # Single call: just report the duration\n                metrics[f\"time/{name}\"] = durations[0]\n            else:\n                # Multiple calls: report aggregates\n                metrics[f\"time/{name}:total\"] = sum(durations)\n                metrics[f\"time/{name}:count\"] = len(durations)\n                metrics[f\"time/{name}:mean\"] = sum(durations) / len(durations)\n                metrics[f\"time/{name}:max\"] = max(durations)\n\n        return metrics\n\n    def get_timing_metrics(self) -> dict[str, float]:\n        \"\"\"Get aggregated timing metrics including time/total.\n\n        Call this after the ``trace_iteration`` context manager has exited,\n        which sets ``_total_time``.\n        \"\"\"\n        metrics = self.aggregate()\n        if self._total_time is not None:\n            metrics[\"time/total\"] = self._total_time\n        return metrics\n\n    def merge_spans(self, spans: list[SpanRecord]) -> None:\n        \"\"\"Merge externally-collected spans (e.g. 
from worker processes) into this window.\"\"\"\n        with self._lock:\n            self.spans.extend(spans)\n\n    def get_span_records(self) -> list[dict[str, Any]]:\n        \"\"\"Get span records for Gantt chart rendering.\n\n        Uses wall-clock timestamps (time.time) so that spans from different\n        processes can be placed on a shared timeline.\n        \"\"\"\n        with self._lock:\n            spans = list(self.spans)\n\n        if not spans:\n            return []\n\n        # Use wall-clock times for positioning — comparable across processes\n        t0 = min(s.wall_start for s in spans)\n        return [\n            {\n                \"task\": s.name,\n                \"start\": datetime.datetime(2000, 1, 1)\n                + datetime.timedelta(seconds=s.wall_start - t0),\n                \"end\": datetime.datetime(2000, 1, 1) + datetime.timedelta(seconds=s.wall_end - t0),\n            }\n            for s in spans\n        ]\n\n    def write_spans_jsonl(self, path: Path | str, step: int) -> None:\n        \"\"\"Append span records for this iteration as one JSON line to the given file.\n\n        Format: ``{\"step\": N, \"spans\": [{\"name\": ..., \"duration\": ..., \"wall_start\": ..., \"wall_end\": ...}, ...]}``\n        \"\"\"\n        with self._lock:\n            spans = list(self.spans)\n\n        if not spans:\n            return\n\n        t0 = min(s.wall_start for s in spans)\n        span_dicts = [\n            {\n                \"name\": s.name,\n                \"duration\": s.end_time - s.start_time,\n                \"wall_start\": s.wall_start - t0,\n                \"wall_end\": s.wall_end - t0,\n            }\n            for s in spans\n        ]\n        line = json.dumps({\"step\": step, \"spans\": span_dicts})\n        with open(path, \"a\") as f:\n            f.write(line + \"\\n\")\n\n\n# Context variable to track the current iteration window\n_iteration_window: ContextVar[IterationWindow | None] = ContextVar(\n    \"_iteration_window\", default=None\n)\n\n\nclass TraceCollector:\n    \"\"\"Collects trace events and exports them in Chrome Trace/Perfetto Format.\"\"\"\n\n    def __init__(self, flush_interval_sec: float = 1.0, output_file: str = \"trace_events.jsonl\"):\n        self.event_queue: queue.Queue[TraceEvent] = queue.Queue()\n        self.flush_interval_sec = flush_interval_sec\n        self.output_file = output_file\n        self.shutdown_event = threading.Event()\n        self.flusher_thread = threading.Thread(target=self._flush_worker, daemon=True)\n        self.flusher_thread.start()\n\n        # Map of (pid, tid) to metadata event\n        self.metadata_events: dict[tuple[int, int], TraceEvent] = {}\n        self.next_fake_pid = 0\n        self.thread_id_to_fake_pid: dict[int, int] = {}\n\n    def add_event(self, event: TraceEvent):\n        \"\"\"Thread-safe addition of trace events.\"\"\"\n        self.event_queue.put(event)\n\n    def get_timestamp(self) -> float:\n        \"\"\"Get current timestamp in microseconds relative to start.\"\"\"\n        return time.perf_counter() * 1e6\n\n    def get_all_events_immediately_available(self) -> list[TraceEvent]:\n        \"\"\"Get all events that are immediately available.\"\"\"\n        events = []\n        while True:\n            try:\n                events.append(self.event_queue.get_nowait())\n            except queue.Empty:\n                break\n        return events\n\n    def _write_events(self, events: list[TraceEvent], f: TextIOWrapper) -> None:\n        for event in 
events:\n            # Map the event pids (thread ids) to fake pids. If pid numbers are large,\n            # Perfetto has issues rendering these as different groups of tracks\n            if event.pid not in self.thread_id_to_fake_pid:\n                self.thread_id_to_fake_pid[event.pid] = self.next_fake_pid\n                self.next_fake_pid += 1\n            event.pid = self.thread_id_to_fake_pid[event.pid]\n\n            # Only log the first metadata event for each pid/tid pair\n            if event.ph == EventType.METADATA:\n                if (event.pid, event.tid) in self.metadata_events:\n                    continue\n                self.metadata_events[(event.pid, event.tid)] = event\n\n            json.dump(event.to_dict(), f)\n            f.write(\"\\n\")\n        f.flush()\n\n    def _flush_worker(self):\n        \"\"\"Background thread worker that periodically flushes events to file.\"\"\"\n        # Use append mode to avoid overwriting previous events when resuming\n        # from a checkpoint\n        with open(self.output_file, \"a\") as f:\n            while not self.shutdown_event.is_set():\n                events_to_write = self.get_all_events_immediately_available()\n\n                # Collect events with a timeout to check shutdown periodically\n                try:\n                    # Get first event with timeout and any additional events that are immediately available\n                    event = self.event_queue.get(timeout=self.flush_interval_sec)\n                    events_to_write.append(event)\n                    events_to_write.extend(self.get_all_events_immediately_available())\n                except queue.Empty:\n                    # No events to flush, continue checking for shutdown\n                    continue\n                self._write_events(events_to_write, f)\n\n            # Flush remaining events on shutdown\n            self._write_events(self.get_all_events_immediately_available(), f)\n\n    def shutdown(self):\n        \"\"\"Shutdown the background flusher thread.\"\"\"\n        self.shutdown_event.set()\n        self.flusher_thread.join(timeout=5.0)\n\n\n# Global trace collector instance\n_trace_collector: TraceCollector | None = None\n\n\ndef _atexit_trace_shutdown():\n    global _trace_collector\n    if _trace_collector is not None:\n        _trace_collector.shutdown()\n        _trace_collector = None\n\n\natexit.register(_atexit_trace_shutdown)\n\n\ndef _instrument_sdk_clients() -> None:\n    \"\"\"Patch Tinker SDK client classes with @scope for automatic tracing.\"\"\"\n    import tinker\n\n    _methods_to_patch = {\n        tinker.TrainingClient: [\n            \"forward_async\",\n            \"forward_backward_async\",\n            \"forward_backward_custom_async\",\n            \"get_info_async\",\n            \"optim_step_async\",\n            \"save_state_async\",\n            \"load_state_async\",\n            \"load_state_with_optimizer_async\",\n            \"save_weights_for_sampler_async\",\n            \"save_weights_and_get_sampling_client_async\",\n            \"create_sampling_client_async\",\n        ],\n        tinker.SamplingClient: [\n            \"sample_async\",\n            \"compute_logprobs_async\",\n            \"get_base_model_async\",\n        ],\n    }\n\n    for cls, method_names in _methods_to_patch.items():\n        for method_name in method_names:\n            if hasattr(cls, method_name):\n                original = getattr(cls, method_name)\n                # Avoid double-wrapping\n               
 if not getattr(original, \"_scope_instrumented\", False):\n                    wrapped = scope(original)\n                    wrapped._scope_instrumented = True  # type: ignore[attr-defined]\n                    setattr(cls, method_name, wrapped)\n\n\ndef trace_init(\n    flush_interval_sec: float = 1.0,\n    output_file: str = \"trace_events.jsonl\",\n) -> None:\n    \"\"\"Initialize the trace collector.\n\n    Args:\n        flush_interval_sec: How often to flush trace events to disk.\n        output_file: Path for Perfetto trace output (JSONL format).\n    \"\"\"\n    global _trace_collector\n    _trace_collector = TraceCollector(flush_interval_sec, output_file)\n    _instrument_sdk_clients()\n\n\ndef trace_shutdown() -> None:\n    \"\"\"Shutdown the trace collector and flush any remaining events.\"\"\"\n    global _trace_collector\n    if _trace_collector is None:\n        return\n    _trace_collector.shutdown()\n    _trace_collector = None\n\n\n@dataclass\nclass FunctionCallContext:\n    \"\"\"Context information for a function call\"\"\"\n\n    scope_context: ScopeContext\n    coroutine_name: str\n    thread_name: str\n    category: str\n    thread_id: int\n\n\n@dataclass\nclass CreateTraceEventsResult:\n    begin_event: TraceEvent\n    metadata_coroutine_event: TraceEvent\n    metadata_thread_event: TraceEvent\n    function_call_context: FunctionCallContext\n\n\ndef _get_trace_thread_info() -> tuple[int, str, str]:\n    \"\"\"Get thread/coroutine info for trace events.\n\n    Returns (thread_id, thread_name, coroutine_name).\n    \"\"\"\n    thread_id = threading.current_thread().ident or 0\n    thread_name = threading.current_thread().name\n    try:\n        task = asyncio.current_task()\n        if task is None:\n            coroutine_name = f\"sync:{thread_name}\"\n        else:\n            coroutine_name = task.get_name()\n    except RuntimeError:\n        coroutine_name = f\"sync:{thread_name}\"\n    return thread_id, thread_name, coroutine_name\n\n\ndef _create_trace_events(name: str) -> CreateTraceEventsResult:\n    \"\"\"Create trace events and context information for a named span.\"\"\"\n    assert _trace_collector is not None, (\n        \"Trace collector must be initialized before creating trace events\"\n    )\n\n    thread_id, thread_name, coroutine_name = _get_trace_thread_info()\n    category = \"async\"\n\n    # Begin event for this function call\n    begin_event = TraceEvent(\n        name=name,\n        ph=EventType.BEGIN,\n        pid=thread_id,  # Process ID (we use thread ID as process)\n        tid=hash(coroutine_name) % 1000000,  # Track ID within the thread\n        ts=_trace_collector.get_timestamp(),\n        args={\n            \"track\": coroutine_name,\n            \"thread\": thread_name,\n        },\n        cat=category,\n    )\n\n    # Metadata events to identify the track names.\n    # In typical perfetto setups, a process has a group of tracks, where each track represnets a thread.\n    # In our case, a group of tracks represents a thread, and a track represents a coroutine running\n    # on that thread.\n    metadata_coroutine_event = TraceEvent(\n        name=\"thread_name\",\n        ph=EventType.METADATA,\n        pid=thread_id,\n        tid=hash(coroutine_name) % 1000000,\n        ts=0,\n        args={\"name\": coroutine_name},\n    )\n    metadata_thread_event = TraceEvent(\n        name=\"process_name\",\n        ph=EventType.METADATA,\n        pid=thread_id,\n        tid=0,\n        ts=0,\n        args={\"name\": f\"{thread_name} 
Thread\"},\n    )\n\n    return CreateTraceEventsResult(\n        begin_event,\n        metadata_coroutine_event,\n        metadata_thread_event,\n        FunctionCallContext(\n            scope_context=ScopeContext(),\n            coroutine_name=coroutine_name,\n            thread_name=thread_name,\n            category=category,\n            thread_id=thread_id,\n        ),\n    )\n\n\ndef _create_end_event(\n    name: str,\n    function_call_context: FunctionCallContext,\n) -> TraceEvent:\n    \"\"\"Create an end trace event for a named span.\"\"\"\n    assert _trace_collector is not None, (\n        \"Trace collector must be initialized before creating trace events\"\n    )\n\n    return TraceEvent(\n        name=name,\n        ph=EventType.END,\n        pid=function_call_context.thread_id,\n        tid=hash(function_call_context.coroutine_name) % 1000000,\n        ts=_trace_collector.get_timestamp(),\n        args={\n            \"track\": function_call_context.coroutine_name,\n            \"thread\": function_call_context.thread_name,\n            **function_call_context.scope_context.attributes,\n        },\n        cat=function_call_context.category,\n    )\n\n\ndef _make_scope_wrapper(func: Callable[..., Any], name: str) -> Callable[..., Any]:\n    \"\"\"Create a scope wrapper for a function with the given span name.\"\"\"\n\n    if inspect.iscoroutinefunction(func):\n\n        @functools.wraps(func)\n        async def async_wrapper(*args: Any, **kwargs: Any):\n            if _trace_collector is None:\n                # Still record into iteration window even without Perfetto tracing\n                window = _iteration_window.get(None)\n                if window is not None:\n                    t_start = time.perf_counter()\n                    try:\n                        return await func(*args, **kwargs)\n                    finally:\n                        window.record_span(name, t_start, time.perf_counter())\n                return await func(*args, **kwargs)\n\n            events_result = _create_trace_events(name)\n            _trace_collector.add_event(events_result.begin_event)\n            _trace_collector.add_event(events_result.metadata_coroutine_event)\n            _trace_collector.add_event(events_result.metadata_thread_event)\n\n            t_start = time.perf_counter()\n            token = None\n            try:\n                # Set context for nested calls\n                token = trace_context.set(events_result.function_call_context.scope_context)\n\n                # Execute the actual function\n                result = await func(*args, **kwargs)\n                return result\n\n            finally:\n                end_event = _create_end_event(name, events_result.function_call_context)\n                _trace_collector.add_event(end_event)\n\n                # Record into iteration window if active\n                window = _iteration_window.get(None)\n                if window is not None:\n                    window.record_span(name, t_start, time.perf_counter())\n\n                # Reset context\n                if token is not None:\n                    trace_context.reset(token)\n\n        return async_wrapper\n\n    else:\n\n        @functools.wraps(func)\n        def sync_wrapper(*args: Any, **kwargs: Any):\n            if _trace_collector is None:\n                # Still record into iteration window even without Perfetto tracing\n                window = _iteration_window.get(None)\n                if window is not None:\n                    
t_start = time.perf_counter()\n                    try:\n                        return func(*args, **kwargs)\n                    finally:\n                        window.record_span(name, t_start, time.perf_counter())\n                return func(*args, **kwargs)\n\n            events_result = _create_trace_events(name)\n            _trace_collector.add_event(events_result.begin_event)\n            _trace_collector.add_event(events_result.metadata_coroutine_event)\n            _trace_collector.add_event(events_result.metadata_thread_event)\n\n            t_start = time.perf_counter()\n            token = None\n            try:\n                # Set context for nested calls\n                token = trace_context.set(events_result.function_call_context.scope_context)\n\n                # Execute the actual function\n                result = func(*args, **kwargs)\n                return result\n\n            finally:\n                end_event = _create_end_event(name, events_result.function_call_context)\n                _trace_collector.add_event(end_event)\n\n                # Record into iteration window if active\n                window = _iteration_window.get(None)\n                if window is not None:\n                    window.record_span(name, t_start, time.perf_counter())\n\n                # Reset context\n                if token is not None:\n                    trace_context.reset(token)\n\n        return sync_wrapper\n\n\ndef scope(func: Callable[..., Any]) -> Callable[..., Any]:\n    \"\"\"\n    Decorator for tracing both async and sync functions. In the resulting trace:\n    - Each track represents a coroutine (or a sync function if not a coroutine)\n    - A thread is a group of tracks, representing all the coroutines running on that thread\n\n    For better tracking, make sure to name all coroutines so that we can group them\n    properly in the trace.\n\n    Example usage:\n\n    from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context\n\n    @scope\n    async def foo():\n        await asyncio.sleep(0.1)\n        # Log additional attributes for this function call into the trace\n        context = get_scope_context()\n        context.attributes[\"foo\"] = 1\n        context.attributes[\"foo2\"] = \"abc\"\n        await bar()\n\n    @scope\n    async def bar():\n        # Name the coroutines so that we can group them properly in the trace\n        await asyncio.gather(\n            asyncio.create_task(baz(), name=\"baz\"),\n            asyncio.create_task(baz(), name=\"baz2\"),\n        )\n\n    @scope\n    async def main():\n        await foo()\n\n    if __name__ == \"__main__\":\n        trace_init()\n        asyncio.run(main())\n    \"\"\"\n    return _make_scope_wrapper(func, func.__name__)\n\n\n@contextlib.asynccontextmanager\nasync def scope_span(name: str):\n    \"\"\"Async context manager for inline named spans.\n\n    Records to both the Perfetto trace (if active) and the current IterationWindow.\n    Use this when you want to time a block of code with a semantic name rather than\n    decorating a function with ``@scope``.\n\n    Example::\n\n        async with scope_span(\"policy_sample\"):\n            result = await policy(observation, stop_condition)\n    \"\"\"\n    window = _iteration_window.get(None)\n\n    if _trace_collector is not None:\n        events_result = _create_trace_events(name)\n        _trace_collector.add_event(events_result.begin_event)\n        _trace_collector.add_event(events_result.metadata_coroutine_event)\n       
 _trace_collector.add_event(events_result.metadata_thread_event)\n\n        t_start = time.perf_counter()\n        try:\n            yield\n        finally:\n            end_event = _create_end_event(name, events_result.function_call_context)\n            _trace_collector.add_event(end_event)\n            if window is not None:\n                window.record_span(name, t_start, time.perf_counter())\n    elif window is not None:\n        t_start = time.perf_counter()\n        try:\n            yield\n        finally:\n            window.record_span(name, t_start, time.perf_counter())\n    else:\n        yield\n\n\n@contextlib.contextmanager\ndef scope_span_sync(name: str):\n    \"\"\"Sync context manager for inline named spans.\n\n    Same as ``scope_span`` but for synchronous code.\n\n    Example::\n\n        with scope_span_sync(\"data_processing\"):\n            result = process_data(batch)\n    \"\"\"\n    window = _iteration_window.get(None)\n\n    if _trace_collector is not None:\n        events_result = _create_trace_events(name)\n        _trace_collector.add_event(events_result.begin_event)\n        _trace_collector.add_event(events_result.metadata_coroutine_event)\n        _trace_collector.add_event(events_result.metadata_thread_event)\n\n        t_start = time.perf_counter()\n        try:\n            yield\n        finally:\n            end_event = _create_end_event(name, events_result.function_call_context)\n            _trace_collector.add_event(end_event)\n            if window is not None:\n                window.record_span(name, t_start, time.perf_counter())\n    elif window is not None:\n        t_start = time.perf_counter()\n        try:\n            yield\n        finally:\n            window.record_span(name, t_start, time.perf_counter())\n    else:\n        yield\n\n\ndef get_scope_context() -> ScopeContext:\n    \"\"\"\n    Call this to get the current scope's context. This allows the functions\n    to log additional attributes into the trace.\n\n    Example usage:\n\n    @scope\n    async def foo():\n        context = get_scope_context()\n        context.attributes[\"foo\"] = 1\n        context.attributes[\"foo2\"] = \"abc\"\n        await bar()\n    \"\"\"\n\n    result = trace_context.get(ScopeContext())\n    assert result is not None, \"Trace context is not set\"\n    return result\n\n\ndef update_scope_context(values: dict[str, Any]) -> None:\n    \"\"\"Update the current scope's context. Example usage:\n\n    @scope\n    async def foo(step: int):\n        update_scope_context({\"step\": step})\n        await bar()\n\n    \"\"\"\n    result = trace_context.get(ScopeContext())\n    assert result is not None, \"Trace context is not set\"\n    result.attributes.update(values)\n\n\ndef _build_gantt_chart(span_records: list[dict[str, Any]], step: int) -> Any:\n    \"\"\"Build a Plotly Gantt chart from span records. 
Returns a plotly Figure or None.\"\"\"\n    try:\n        import plotly.express as px  # type: ignore[reportMissingImports]\n    except ImportError:\n        logger.debug(\"plotly not installed, skipping Gantt chart\")\n        return None\n\n    if not span_records:\n        return None\n\n    fig = px.timeline(\n        span_records,\n        x_start=\"start\",\n        x_end=\"end\",\n        y=\"task\",\n        color=\"task\",\n        title=f\"Iteration {step} — Span Timeline\",\n    )\n    fig.update_layout(\n        xaxis_title=\"Time (relative)\",\n        yaxis_title=\"\",\n        showlegend=False,\n    )\n    return fig\n\n\ndef save_gantt_chart_html(window: IterationWindow, step: int, path: Path | str) -> None:\n    \"\"\"Build a Plotly Gantt chart from the window's spans and save as standalone HTML.\n\n    No-op if plotly is not installed or the window has no spans.\n    \"\"\"\n    span_records = window.get_span_records()\n    fig = _build_gantt_chart(span_records, step)\n    if fig is not None:\n        fig.write_html(str(path))\n\n\n@contextlib.contextmanager\ndef trace_iteration(step: int) -> Generator[IterationWindow, None, None]:\n    \"\"\"Context manager that marks a training iteration boundary.\n\n    Yields an ``IterationWindow`` that collects all ``@scope`` and ``scope_span``\n    spans within the block. After the block exits, call ``window.get_timing_metrics()``\n    to retrieve the aggregated timing dict (including ``time/total``).\n\n    Span names are flat (the function or span name), not hierarchical. If ``train_step``\n    calls ``forward_backward_async``, both appear as separate top-level keys::\n\n        time/train_step = 5.0              # inclusive (contains forward_backward)\n        time/forward_backward_async = 3.0  # just the inner call\n\n    For functions called multiple times (e.g. 160 concurrent ``sample_async``\n    calls), aggregated keys are produced::\n\n        time/sample_async:total = 480.0\n        time/sample_async:count = 160\n        time/sample_async:mean  = 3.0\n        time/sample_async:max   = 4.9\n\n    Example::\n\n        for i_batch in range(n_batches):\n            with trace_iteration(step=i_batch) as window:\n                await run_evals(...)\n                await gather_rollouts(...)\n                await train_step(...)\n            metrics.update(window.get_timing_metrics())\n            window.write_spans_jsonl(log_path / \"timing_spans.jsonl\", step=i_batch)\n            ml_logger.log_metrics(metrics, step=i_batch)\n    \"\"\"\n    window = IterationWindow()\n    token = _iteration_window.set(window)\n    t_start = time.perf_counter()\n    try:\n        yield window\n    finally:\n        window._total_time = time.perf_counter() - t_start\n        _iteration_window.reset(token)\n\n\ndef convert_jsonl_to_json_main():\n    \"\"\"Helper script to convert the trace events format into a visualizable format\"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Convert trace events from JSONL format to JSON format for visualization in chrome://tracing or https://ui.perfetto.dev/\"\n    )\n    parser.add_argument(\"trace_events_jsonl_file\", type=str)\n    parser.add_argument(\"output_json_file\", type=str)\n    args = parser.parse_args()\n\n    with open(args.trace_events_jsonl_file) as f:\n        events = [json.loads(line) for line in f]\n    with open(args.output_json_file, \"w\") as f:\n        json.dump(events, f)\n    print(f\"\"\"To view the trace:\n1. 
Navigate to chrome://tracing or https://ui.perfetto.dev/\n2. Load the trace file: {args.output_json_file}\"\"\")\n\n\nif __name__ == \"__main__\":\n    convert_jsonl_to_json_main()\n"
  },
  {
    "path": "tinker_cookbook/utils/trace_test.py",
    "content": "import asyncio\nimport contextlib\nimport inspect\nimport json\nimport tempfile\nimport threading\nimport time\nfrom pathlib import Path\nfrom unittest.mock import patch\n\nimport tinker\n\nfrom tinker_cookbook.utils.trace import (\n    IterationWindow,\n    SpanRecord,\n    _build_gantt_chart,\n    get_scope_context,\n    save_gantt_chart_html,\n    scope,\n    scope_span,\n    scope_span_sync,\n    trace_init,\n    trace_iteration,\n    trace_shutdown,\n    update_scope_context,\n)\n\n# --- Helpers ---\n\n\n@contextlib.contextmanager\ndef trace_session():\n    \"\"\"Start a trace session backed by a temporary JSONL file.\"\"\"\n    with tempfile.NamedTemporaryFile(suffix=\".jsonl\", delete=True) as f:\n        trace_init(output_file=f.name)\n        try:\n            yield f.name\n        finally:\n            trace_shutdown()\n\n\n# --- Decorated helpers for test_trace (multi-thread integration) ---\n#\n# These must be module-level because @scope captures __name__ at decoration time.\n# They are used only by test_trace.\n\n\n@scope\nasync def foo():\n    await asyncio.sleep(0.1)\n    context = get_scope_context()\n    context.attributes[\"foo\"] = \"foo\"\n    context.attributes[\"foo2\"] = 1\n    await bar()\n\n\n@scope\nasync def bar():\n    await asyncio.sleep(0.05)\n    context = get_scope_context()\n    context.attributes[\"bar\"] = 1\n    await baz()\n\n\n@scope\ndef ced():\n    pass\n\n\n@scope\nasync def baz():\n    await asyncio.sleep(0.02)\n    update_scope_context({\"baz\": \"baz\"})\n    ced()\n\n\n@scope\nasync def coroutine1():\n    await foo()\n    await asyncio.sleep(0.05)\n\n\n@scope\nasync def coroutine2():\n    await asyncio.sleep(0.15)\n    await foo()\n\n\n@scope\ndef sync_func():\n    pass\n\n\n@scope\nasync def work(thread_name: str):\n    task1 = asyncio.create_task(coroutine1(), name=f\"{thread_name}-coroutine-1\")\n    task2 = asyncio.create_task(coroutine2(), name=f\"{thread_name}-coroutine-2\")\n    sync_func()\n    await asyncio.gather(task1, task2)\n\n\n@scope\nasync def example_program():\n    @scope\n    def thread_target():\n        asyncio.run(work(\"secondary_thread\"))\n\n    thread = threading.Thread(target=thread_target, name=\"secondary_thread\")\n    thread.start()\n\n    await work(\"main_thread\")\n\n    thread.join()\n\n\n# --- @scope decorator ---\n\n\ndef test_trace():\n    with tempfile.NamedTemporaryFile(suffix=\".jsonl\", delete=True) as temp_file:\n        trace_init(output_file=temp_file.name)\n        asyncio.run(example_program())\n        trace_shutdown()\n\n        with open(temp_file.name) as f:\n            events = [json.loads(line) for line in f]\n\n        # There should be 2 process metadata events\n        num_metadata_pid_events = sum(\n            1 for event in events if event[\"ph\"] == \"M\" and event[\"tid\"] == 0\n        )\n        assert num_metadata_pid_events == 2\n        num_unique_pids = len({event[\"pid\"] for event in events if event[\"ph\"] != \"M\"})\n        assert num_unique_pids == 2\n\n        # main thread has 3: main, coroutine-1, coroutine-2\n        # secondary thread has 4: thread_target, work, coroutine-1, coroutine-2\n        num_metadata_tid_events = sum(\n            1 for event in events if event[\"ph\"] == \"M\" and event[\"tid\"] != 0\n        )\n        assert num_metadata_tid_events == 7\n        num_unique_tids = len({event[\"tid\"] for event in events if event[\"ph\"] != \"M\"})\n        assert num_unique_tids == 7\n\n    # Validate that attributes are set correctly\n    
for event in events:\n        if event[\"ph\"] != \"E\":\n            continue\n        if event[\"name\"] == \"foo\":\n            assert event[\"args\"][\"foo\"] == \"foo\"\n            assert event[\"args\"][\"foo2\"] == 1\n        if event[\"name\"] == \"bar\":\n            assert event[\"args\"][\"bar\"] == 1\n        if event[\"name\"] == \"baz\":\n            assert event[\"args\"][\"baz\"] == \"baz\"\n\n\ndef test_scope_noop_async():\n    \"\"\"Async @scope passes through when no collector and no iteration window.\"\"\"\n\n    @scope\n    async def noop_async():\n        return 42\n\n    # No trace_init, no trace_iteration — should just return the value\n    result = asyncio.run(noop_async())\n    assert result == 42\n\n\ndef test_scope_noop_sync():\n    \"\"\"Sync @scope passes through when no collector and no iteration window.\"\"\"\n\n    @scope\n    def noop_sync():\n        return 99\n\n    # No trace_init, no trace_iteration — should just return the value\n    result = noop_sync()\n    assert result == 99\n\n\n# --- IterationWindow ---\n\n\ndef test_iteration_window_single_span():\n    window = IterationWindow()\n    window.record_span(\"train_step\", 0.0, 1.5)\n    metrics = window.aggregate()\n    assert metrics == {\"time/train_step\": 1.5}\n\n\ndef test_iteration_window_multiple_spans_same_name():\n    window = IterationWindow()\n    window.record_span(\"sample\", 0.0, 2.0)\n    window.record_span(\"sample\", 0.1, 3.0)\n    window.record_span(\"sample\", 0.2, 1.5)\n    metrics = window.aggregate()\n    assert metrics[\"time/sample:count\"] == 3\n    assert metrics[\"time/sample:total\"] == 2.0 + 2.9 + 1.3\n    assert abs(metrics[\"time/sample:mean\"] - (2.0 + 2.9 + 1.3) / 3) < 1e-9\n    assert metrics[\"time/sample:max\"] == 2.9\n\n\ndef test_iteration_window_mixed_spans():\n    window = IterationWindow()\n    window.record_span(\"eval\", 0.0, 1.0)\n    window.record_span(\"sample\", 1.0, 3.0)\n    window.record_span(\"sample\", 1.1, 2.5)\n    window.record_span(\"train\", 3.0, 4.0)\n    metrics = window.aggregate()\n    # eval: single call\n    assert metrics[\"time/eval\"] == 1.0\n    # sample: two calls\n    assert metrics[\"time/sample:count\"] == 2\n    # train: single call\n    assert metrics[\"time/train\"] == 1.0\n\n\ndef test_iteration_window_empty():\n    window = IterationWindow()\n    assert window.aggregate() == {}\n    assert window.get_span_records() == []\n\n\ndef test_iteration_window_span_records():\n    window = IterationWindow()\n    window.record_span(\"a\", 100.0, 101.0)\n    window.record_span(\"b\", 100.5, 102.0)\n    records = window.get_span_records()\n    assert len(records) == 2\n    assert records[0][\"task\"] == \"a\"\n    assert records[1][\"task\"] == \"b\"\n    # start times should be relative (first span starts at 0)\n    assert records[0][\"start\"] < records[1][\"start\"]\n\n\ndef test_merge_spans():\n    \"\"\"merge_spans integrates external spans into the window.\"\"\"\n    window = IterationWindow()\n    window.record_span(\"local\", 0.0, 1.0)\n\n    external = [\n        SpanRecord(name=\"worker\", start_time=0.5, end_time=2.0, wall_start=1000.5, wall_end=1002.0),\n    ]\n    window.merge_spans(external)\n\n    metrics = window.aggregate()\n    assert \"time/local\" in metrics\n    assert \"time/worker\" in metrics\n\n    records = window.get_span_records()\n    assert len(records) == 2\n\n\ndef test_get_timing_metrics():\n    \"\"\"get_timing_metrics includes time/total when set by trace_iteration.\"\"\"\n    window = 
IterationWindow()\n    window.record_span(\"op\", 0.0, 1.0)\n    window._total_time = 2.5\n    metrics = window.get_timing_metrics()\n    assert metrics[\"time/op\"] == 1.0\n    assert metrics[\"time/total\"] == 2.5\n\n\ndef test_get_timing_metrics_without_total():\n    \"\"\"get_timing_metrics works without time/total (no trace_iteration).\"\"\"\n    window = IterationWindow()\n    window.record_span(\"op\", 0.0, 1.0)\n    metrics = window.get_timing_metrics()\n    assert metrics[\"time/op\"] == 1.0\n    assert \"time/total\" not in metrics\n\n\n# --- write_spans_jsonl ---\n\n\ndef test_write_spans_jsonl():\n    \"\"\"write_spans_jsonl appends one JSON line per call.\"\"\"\n    window = IterationWindow()\n    window.record_span(\"a\", 100.0, 101.5)\n    window.record_span(\"b\", 100.2, 102.0)\n\n    with tempfile.NamedTemporaryFile(suffix=\".jsonl\", delete=True, mode=\"w\") as f:\n        path = f.name\n\n    window.write_spans_jsonl(path, step=0)\n    window.write_spans_jsonl(path, step=1)\n\n    with open(path) as f:\n        lines = [json.loads(line) for line in f]\n\n    assert len(lines) == 2\n    assert lines[0][\"step\"] == 0\n    assert lines[1][\"step\"] == 1\n    assert len(lines[0][\"spans\"]) == 2\n    assert lines[0][\"spans\"][0][\"name\"] == \"a\"\n    assert lines[0][\"spans\"][1][\"name\"] == \"b\"\n    assert abs(lines[0][\"spans\"][0][\"duration\"] - 1.5) < 1e-9\n    # wall_start of first span should be ~0 (relative)\n    assert lines[0][\"spans\"][0][\"wall_start\"] < 0.1\n\n    Path(path).unlink(missing_ok=True)\n\n\ndef test_write_spans_jsonl_empty_window():\n    \"\"\"write_spans_jsonl is a no-op for empty windows.\"\"\"\n    window = IterationWindow()\n    with tempfile.NamedTemporaryFile(suffix=\".jsonl\", delete=True, mode=\"w\") as f:\n        path = f.name\n\n    window.write_spans_jsonl(path, step=0)\n    assert not Path(path).exists()\n\n\n# --- trace_iteration ---\n\n\ndef test_trace_iteration_collects_scoped_spans():\n    \"\"\"trace_iteration collects spans from @scope-decorated functions.\"\"\"\n\n    @scope\n    async def fast_op():\n        await asyncio.sleep(0.01)\n\n    @scope\n    async def slow_op():\n        await asyncio.sleep(0.05)\n\n    async def run():\n        with trace_session():\n            with trace_iteration(step=0) as window:\n                await fast_op()\n                await slow_op()\n            return window\n\n    window = asyncio.run(run())\n    metrics = window.get_timing_metrics()\n    assert \"time/total\" in metrics\n    assert \"time/fast_op\" in metrics\n    assert \"time/slow_op\" in metrics\n    assert metrics[\"time/slow_op\"] > metrics[\"time/fast_op\"]\n\n\ndef test_trace_iteration_aggregates_repeated_calls():\n    \"\"\"Repeated calls to the same @scope function produce aggregate metrics.\"\"\"\n\n    @scope\n    async def repeated_op():\n        await asyncio.sleep(0.01)\n\n    async def run():\n        with trace_session():\n            with trace_iteration(step=5) as window:\n                await asyncio.gather(\n                    repeated_op(),\n                    repeated_op(),\n                    repeated_op(),\n                )\n            return window\n\n    window = asyncio.run(run())\n    metrics = window.get_timing_metrics()\n    assert metrics[\"time/repeated_op:count\"] == 3\n    assert \"time/repeated_op:mean\" in metrics\n    assert \"time/repeated_op:max\" in metrics\n    assert \"time/repeated_op:total\" in metrics\n\n\ndef test_trace_iteration_without_trace_init():\n    
\"\"\"trace_iteration works even without trace_init (no Perfetto, just span collection).\"\"\"\n\n    @scope\n    async def some_work():\n        await asyncio.sleep(0.01)\n\n    async def run():\n        # No trace_init — _trace_collector is None\n        with trace_iteration(step=0) as window:\n            await some_work()\n        return window\n\n    window = asyncio.run(run())\n    metrics = window.get_timing_metrics()\n    assert \"time/some_work\" in metrics\n    assert \"time/total\" in metrics\n\n\ndef test_trace_iteration_with_perfetto_only():\n    \"\"\"trace_iteration with Perfetto but caller doesn't use timing metrics.\"\"\"\n\n    @scope\n    async def op():\n        await asyncio.sleep(0.01)\n\n    async def run():\n        with trace_session():\n            with trace_iteration(step=0) as window:\n                await op()\n            return window\n\n    window = asyncio.run(run())\n    # Caller can choose to ignore the window — no crash\n    assert \"time/op\" in window.get_timing_metrics()\n\n\ndef test_trace_iteration_sync_functions():\n    \"\"\"trace_iteration collects spans from sync @scope-decorated functions.\"\"\"\n\n    @scope\n    def sync_work():\n        time.sleep(0.01)\n\n    async def run():\n        with trace_session():\n            with trace_iteration(step=0) as window:\n                sync_work()\n                sync_work()\n            return window\n\n    window = asyncio.run(run())\n    metrics = window.get_timing_metrics()\n    assert metrics[\"time/sync_work:count\"] == 2\n\n\ndef test_trace_iteration_on_exception():\n    \"\"\"trace_iteration still captures partial timing when an exception occurs.\"\"\"\n\n    @scope\n    async def succeeds():\n        await asyncio.sleep(0.01)\n\n    @scope\n    async def fails():\n        await asyncio.sleep(0.01)\n        raise ValueError(\"boom\")\n\n    async def run():\n        with trace_session():\n            with trace_iteration(step=0) as window:\n                try:\n                    await succeeds()\n                    await fails()\n                except ValueError:\n                    pass\n            return window\n\n    window = asyncio.run(run())\n    metrics = window.get_timing_metrics()\n    assert \"time/total\" in metrics\n    assert \"time/succeeds\" in metrics\n    assert \"time/fails\" in metrics\n\n\ndef test_trace_iteration_nested():\n    \"\"\"Nested trace_iteration: inner window is independent from outer.\"\"\"\n\n    @scope\n    async def outer_op():\n        await asyncio.sleep(0.01)\n\n    @scope\n    async def inner_op():\n        await asyncio.sleep(0.01)\n\n    async def run():\n        with trace_session():\n            with trace_iteration(step=0) as outer_window:\n                await outer_op()\n                with trace_iteration(step=100) as inner_window:\n                    await inner_op()\n            return outer_window, inner_window\n\n    outer_window, inner_window = asyncio.run(run())\n\n    # Inner should only have inner_op\n    inner_metrics = inner_window.get_timing_metrics()\n    assert \"time/inner_op\" in inner_metrics\n    assert \"time/outer_op\" not in inner_metrics\n\n    # Outer should have outer_op (inner_op was captured by inner window, not outer)\n    outer_metrics = outer_window.get_timing_metrics()\n    assert \"time/outer_op\" in outer_metrics\n\n\n# --- scope_span ---\n\n\ndef test_scope_span_async():\n    \"\"\"scope_span records to iteration window.\"\"\"\n\n    async def run():\n        with trace_iteration(step=0) as window:\n   
         async with scope_span(\"my_span\"):\n                await asyncio.sleep(0.01)\n        return window\n\n    window = asyncio.run(run())\n    metrics = window.get_timing_metrics()\n    assert \"time/my_span\" in metrics\n    assert metrics[\"time/my_span\"] >= 0.01\n\n\ndef test_scope_span_with_perfetto():\n    \"\"\"scope_span records to both Perfetto and iteration window.\"\"\"\n\n    with tempfile.NamedTemporaryFile(suffix=\".jsonl\", delete=False) as tmp:\n        trace_file = tmp.name\n\n    try:\n\n        async def run():\n            trace_init(output_file=trace_file)\n            try:\n                with trace_iteration(step=0) as window:\n                    async with scope_span(\"traced_span\"):\n                        await asyncio.sleep(0.01)\n                return window\n            finally:\n                trace_shutdown()\n\n        window = asyncio.run(run())\n\n        # Check iteration window\n        metrics = window.get_timing_metrics()\n        assert \"time/traced_span\" in metrics\n\n        # Check Perfetto trace file has the span\n        with open(trace_file) as f:\n            events = [json.loads(line) for line in f]\n        span_events = [e for e in events if e.get(\"name\") == \"traced_span\"]\n        assert len(span_events) >= 2  # BEGIN + END\n    finally:\n        Path(trace_file).unlink(missing_ok=True)\n\n\ndef test_scope_span_noop():\n    \"\"\"scope_span is a no-op when no collector and no iteration window.\"\"\"\n\n    async def run():\n        async with scope_span(\"should_not_crash\"):\n            return 42\n\n    result = asyncio.run(run())\n    assert result == 42\n\n\ndef test_scope_span_sync():\n    \"\"\"scope_span_sync records to iteration window.\"\"\"\n\n    with trace_iteration(step=0) as window:\n        with scope_span_sync(\"sync_span\"):\n            time.sleep(0.01)\n\n    metrics = window.get_timing_metrics()\n    assert \"time/sync_span\" in metrics\n\n\ndef test_scope_span_multiple():\n    \"\"\"Multiple scope_span calls with the same name produce aggregates.\"\"\"\n\n    async def run():\n        with trace_iteration(step=0) as window:\n            for _ in range(3):\n                async with scope_span(\"repeated\"):\n                    await asyncio.sleep(0.01)\n        return window\n\n    window = asyncio.run(run())\n    metrics = window.get_timing_metrics()\n    assert metrics[\"time/repeated:count\"] == 3\n    assert \"time/repeated:total\" in metrics\n\n\ndef test_scope_span_on_exception():\n    \"\"\"scope_span still records the span when the block raises.\"\"\"\n\n    async def run():\n        with trace_iteration(step=0) as window:\n            async with scope_span(\"before_error\"):\n                await asyncio.sleep(0.01)\n            try:\n                async with scope_span(\"erroring\"):\n                    await asyncio.sleep(0.01)\n                    raise ValueError(\"boom\")\n            except ValueError:\n                pass\n        return window\n\n    window = asyncio.run(run())\n    metrics = window.get_timing_metrics()\n    assert \"time/before_error\" in metrics\n    assert \"time/erroring\" in metrics\n\n\n# --- Gantt chart ---\n\n\ndef test_build_gantt_chart_success():\n    \"\"\"_build_gantt_chart returns a figure when plotly is available and spans are non-empty.\"\"\"\n    import datetime\n\n    span_records = [\n        {\n            \"task\": \"a\",\n            \"start\": datetime.datetime(2000, 1, 1),\n            \"end\": datetime.datetime(2000, 1, 1, 0, 0, 1),\n   
     },\n        {\n            \"task\": \"b\",\n            \"start\": datetime.datetime(2000, 1, 1, 0, 0, 0, 500000),\n            \"end\": datetime.datetime(2000, 1, 1, 0, 0, 2),\n        },\n    ]\n    fig = _build_gantt_chart(span_records, step=0)\n    # If plotly is installed, we get a figure; if not, None\n    try:\n        import plotly  # noqa: F401\n\n        assert fig is not None\n    except ImportError:\n        assert fig is None\n\n\ndef test_build_gantt_chart_empty_spans():\n    \"\"\"_build_gantt_chart returns None for empty span list.\"\"\"\n    # Even if plotly is installed, empty spans should return None\n    fig = _build_gantt_chart([], step=0)\n    assert fig is None\n\n\ndef test_build_gantt_chart_no_plotly():\n    \"\"\"_build_gantt_chart returns None when plotly is not importable.\"\"\"\n    import datetime\n\n    span_records = [\n        {\n            \"task\": \"a\",\n            \"start\": datetime.datetime(2000, 1, 1),\n            \"end\": datetime.datetime(2000, 1, 1, 0, 0, 1),\n        },\n    ]\n    with patch.dict(\"sys.modules\", {\"plotly\": None, \"plotly.express\": None}):\n        fig = _build_gantt_chart(span_records, step=0)\n    assert fig is None\n\n\ndef test_save_gantt_chart_html():\n    \"\"\"save_gantt_chart_html writes an HTML file when plotly is available.\"\"\"\n    window = IterationWindow()\n    window.record_span(\"a\", 100.0, 101.0)\n    window.record_span(\"b\", 100.5, 102.0)\n\n    with tempfile.NamedTemporaryFile(suffix=\".html\", delete=False) as f:\n        path = Path(f.name)\n\n    save_gantt_chart_html(window, step=0, path=path)\n\n    try:\n        import plotly  # noqa: F401\n\n        assert path.exists()\n        content = path.read_text()\n        assert \"plotly\" in content.lower() or \"Plotly\" in content\n    except ImportError:\n        # plotly not installed — file should not be created\n        pass\n    finally:\n        path.unlink(missing_ok=True)\n\n\n# --- SDK client instrumentation ---\n\n\ndef test_sdk_client_instrumentation_covers_all_async_methods():\n    \"\"\"Tripwire: catches new async methods added to Tinker SDK clients.\n\n    If this test fails after a tinker dependency bump, add the new method(s) to\n    _instrument_sdk_clients in trace.py.\n    \"\"\"\n    # Collect all public async methods from SDK clients\n    sdk_async_methods: dict[type, set[str]] = {}\n    for cls in (tinker.TrainingClient, tinker.SamplingClient):\n        methods = set()\n        for name in dir(cls):\n            if name.startswith(\"_\"):\n                continue\n            attr = getattr(cls, name, None)\n            if attr is not None and inspect.iscoroutinefunction(attr):\n                methods.add(name)\n        sdk_async_methods[cls] = methods\n\n    # Instrument via trace_init and check all are wrapped\n    with trace_session():\n        for cls, methods in sdk_async_methods.items():\n            for method_name in methods:\n                original = getattr(cls, method_name)\n                assert getattr(original, \"_scope_instrumented\", False), (\n                    f\"{cls.__name__}.{method_name} is an async method but not instrumented by \"\n                    f\"_instrument_sdk_clients. 
Add it to the method list in trace.py.\"\n                )\n\n\ndef test_scope_double_wrapping_prevention():\n    \"\"\"_instrument_sdk_clients is idempotent — calling trace_init twice doesn't double-wrap.\"\"\"\n    with trace_session():\n        first_ref = tinker.TrainingClient.forward_backward_async\n        assert getattr(first_ref, \"_scope_instrumented\", False)\n\n    # Second trace_init — should not re-wrap\n    with trace_session():\n        second_ref = tinker.TrainingClient.forward_backward_async\n        assert getattr(second_ref, \"_scope_instrumented\", False)\n        # Same wrapper object — not double-wrapped\n        assert first_ref is second_ref\n\n\nif __name__ == \"__main__\":\n    trace_init()\n    asyncio.run(example_program())\n    trace_shutdown()\n"
  },
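The tests above exercise the tracing API piece by piece; the sketch below shows how the same primitives might be combined in a training loop. This is a minimal sketch, not the library's documented usage: the import path `tinker_cookbook.utils.trace` is an assumption (the real module is whatever the tests import from), and the loop body is a stand-in for real work. The metric keys mirror the ones asserted in the tests.

```python
import asyncio

# Assumed import path — the tests above use these names from the module they exercise.
from tinker_cookbook.utils.trace import (
    scope,
    scope_span,
    trace_init,
    trace_iteration,
    trace_shutdown,
)


@scope
async def run_batch() -> None:
    await asyncio.sleep(0.01)  # stand-in for forward/backward + sampling work


async def main() -> None:
    trace_init(output_file="trace.jsonl")  # optional: also emit a Perfetto trace
    try:
        for step in range(3):
            with trace_iteration(step=step) as window:
                await run_batch()
                async with scope_span("logging"):
                    await asyncio.sleep(0.001)
            # e.g. {"time/run_batch": ..., "time/logging": ..., "time/total": ...}
            print(step, window.get_timing_metrics())
    finally:
        trace_shutdown()


asyncio.run(main())
```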
  {
    "path": "tinker_cookbook/weights/__init__.py",
    "content": "\"\"\"Weight lifecycle utilities for Tinker training.\n\nProvides functions for downloading, building, and publishing trained model\nweights. The typical workflow is:\n\n    download → build → publish\n\nEach function takes local paths as input/output, making them composable\nand independently testable.\n\nExample::\n\n    from tinker_cookbook import weights\n\n    adapter_dir = weights.download(\n        tinker_path=\"tinker://run-id/sampler_weights/final\",\n        output_dir=\"./adapter\",\n    )\n    weights.build_hf_model(\n        base_model=\"Qwen/Qwen3.5-35B-A3B\",\n        adapter_path=adapter_dir,\n        output_path=\"./model\",\n    )\n    weights.publish_to_hf_hub(model_path=\"./model\", repo_id=\"user/my-finetuned-model\")\n\"\"\"\n\nfrom tinker_cookbook.weights._download import download\nfrom tinker_cookbook.weights._export import build_hf_model\nfrom tinker_cookbook.weights._publish import publish_to_hf_hub\n\n__all__ = [\n    \"download\",\n    \"build_hf_model\",\n    \"build_lora_adapter\",\n    \"publish_to_hf_hub\",\n]\n\n\ndef build_lora_adapter(\n    *,\n    base_model: str,\n    adapter_path: str,\n    output_path: str,\n) -> None:\n    \"\"\"Convert a Tinker LoRA adapter to standard LoRA format.\n\n    The output can be loaded directly by vLLM (``--lora-modules``),\n    SGLang, or any framework supporting LoRA adapters without merging\n    into base model weights.\n\n    Args:\n        base_model: HuggingFace model name or local path. Needed to\n            resolve model-specific weight naming conventions.\n        adapter_path: Local path to the Tinker adapter directory\n            (must contain ``adapter_model.safetensors`` and\n            ``adapter_config.json``).\n        output_path: Directory where the standard LoRA adapter will\n            be saved.\n    \"\"\"\n    raise NotImplementedError(\n        \"build_lora_adapter is not yet implemented. \"\n        \"Use build_hf_model to merge the adapter into a full HF model instead.\"\n    )\n"
  },
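Since `build_lora_adapter` is exported in `__all__` but currently raises, here is a hedged sketch of the fallback its error message points to. Paths and the model name are illustrative; the try/except framing is only to show the two routes side by side.

```python
from tinker_cookbook import weights

adapter_dir = "./adapter"  # must contain adapter_model.safetensors and adapter_config.json

try:
    # Standalone-adapter export is declared in __all__ but not implemented yet.
    weights.build_lora_adapter(
        base_model="Qwen/Qwen3-8B",
        adapter_path=adapter_dir,
        output_path="./lora_adapter",
    )
except NotImplementedError:
    # Supported route today: merge the adapter into a full HF model instead.
    weights.build_hf_model(
        base_model="Qwen/Qwen3-8B",
        adapter_path=adapter_dir,
        output_path="./merged_model",
    )
```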
  {
    "path": "tinker_cookbook/weights/_artifacts.py",
    "content": "\"\"\"Model artifact I/O utilities for weight export.\n\nProvides utilities for reading safetensors metadata, writing sharded output,\nloading adapters, resolving model directories, and copying non-weight files.\nUsed by both standard export strategies (``_export/_full.py``,\n``_export/_shard.py``) and model-specific export modules.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nimport os\nimport shutil\nfrom pathlib import Path\n\nimport torch\nfrom safetensors import safe_open\nfrom safetensors.torch import load_file, save_file\n\nlogger = logging.getLogger(__name__)\n\n# Custom model code files to copy (for trust_remote_code models).\n_MODEL_CODE_PATTERNS = (\"*.py\",)\n\n_MAX_SHARD_SIZE = 10 * (1024**3)  # 10 GB\n\n\ndef copy_artifact_file(src: Path, dst: Path) -> None:\n    \"\"\"Copy file contents without preserving source metadata.\n\n    Some output destinations (for example GCS/FUSE mounts) reject the timestamp\n    updates that `shutil.copy2()` performs via `copystat()`. Export artifacts do\n    not require source mtimes, so a content-only copy is sufficient here.\n    \"\"\"\n    shutil.copyfile(src, dst)\n\n\n# ---------------------------------------------------------------------------\n# Reading model metadata without loading weights\n# ---------------------------------------------------------------------------\n\n\ndef _raise_no_safetensors(model_dir: Path) -> None:\n    \"\"\"Raise FileNotFoundError with a helpful message for missing safetensors.\"\"\"\n    bin_files = sorted(model_dir.glob(\"*.bin\"))\n    if bin_files:\n        raise FileNotFoundError(\n            f\"No .safetensors files found in {model_dir}. \"\n            f\"Found {len(bin_files)} .bin file(s) — this model may use the older PyTorch format. \"\n            f\"Try merge_strategy='full' which loads via from_pretrained and handles both formats.\"\n        )\n    raise FileNotFoundError(\n        f\"No .safetensors files found in {model_dir}. \"\n        f\"Ensure the model has been fully downloaded, or try merge_strategy='full'.\"\n    )\n\n\ndef get_model_state_keys(model_dir: Path) -> set[str]:\n    \"\"\"Get all weight key names from safetensors files without loading tensor data.\n\n    Uses ``safetensors.safe_open`` to read headers only, which is fast and\n    uses negligible memory regardless of model size.\n\n    Args:\n        model_dir: Directory containing ``.safetensors`` files.\n\n    Returns:\n        Set of all tensor key names across all shard files.\n\n    Raises:\n        FileNotFoundError: If no ``.safetensors`` files are found.\n    \"\"\"\n    return set(get_model_state_shapes(model_dir).keys())\n\n\ndef get_model_state_shapes(model_dir: Path) -> dict[str, tuple[int, ...]]:\n    \"\"\"Get shape for each weight key from safetensors headers without loading tensor data.\n\n    Uses ``safetensors.safe_open`` to read headers only. This is fast and uses\n    negligible memory regardless of model size. 
Useful for upfront shape\n    validation before loading any weight shards.\n\n    Args:\n        model_dir: Directory containing ``.safetensors`` files.\n\n    Returns:\n        Dict mapping tensor key names to their shapes.\n\n    Raises:\n        FileNotFoundError: If no ``.safetensors`` files are found.\n    \"\"\"\n    shard_files = sorted(model_dir.glob(\"*.safetensors\"))\n    if not shard_files:\n        _raise_no_safetensors(model_dir)\n\n    shapes: dict[str, tuple[int, ...]] = {}\n    for sf_path in shard_files:\n        with safe_open(str(sf_path), framework=\"pt\") as f:\n            for key in f.keys():  # noqa: SIM118 — safe_open doesn't support `in`\n                shapes[key] = tuple(f.get_slice(key).get_shape())\n    return shapes\n\n\ndef get_shard_files(model_dir: Path) -> list[str]:\n    \"\"\"Get sorted list of safetensors shard filenames in a model directory.\n\n    Prefers reading ``model.safetensors.index.json`` for the canonical shard\n    list. Falls back to globbing for ``.safetensors`` files.\n\n    Args:\n        model_dir: Directory containing the model shards.\n\n    Returns:\n        Sorted list of shard filenames (not full paths).\n\n    Raises:\n        FileNotFoundError: If no ``.safetensors`` files are found.\n    \"\"\"\n    index_path = model_dir / \"model.safetensors.index.json\"\n    if index_path.exists():\n        with open(index_path) as f:\n            weight_map = json.load(f)[\"weight_map\"]\n        return sorted(set(weight_map.values()))\n\n    shard_files = sorted(model_dir.glob(\"*.safetensors\"))\n    if not shard_files:\n        _raise_no_safetensors(model_dir)\n    return [f.name for f in shard_files]\n\n\n# ---------------------------------------------------------------------------\n# Model directory resolution\n# ---------------------------------------------------------------------------\n\n\ndef resolve_model_dir(base_model: str) -> Path:\n    \"\"\"Resolve a HuggingFace model name or local path to a local directory.\n\n    If ``base_model`` is already a local directory, returns it directly.\n    Otherwise downloads from HuggingFace Hub via ``snapshot_download``.\n\n    Args:\n        base_model: HuggingFace model name (e.g. 
``\"Qwen/Qwen3-8B\"``) or\n            local path to a model directory.\n\n    Returns:\n        Path to local directory containing model files.\n    \"\"\"\n    if os.path.isdir(base_model):\n        logger.info(\"Using local model directory: %s\", base_model)\n        return Path(base_model)\n\n    from huggingface_hub import snapshot_download\n\n    logger.info(\"Downloading model files for %s\", base_model)\n    local_dir = snapshot_download(repo_id=base_model)\n    return Path(local_dir)\n\n\n# ---------------------------------------------------------------------------\n# Adapter loading\n# ---------------------------------------------------------------------------\n\n\ndef load_adapter_weights(adapter_dir: Path) -> tuple[dict[str, torch.Tensor], dict]:\n    \"\"\"Load adapter weights and config from disk.\n\n    Args:\n        adapter_dir: Directory containing ``adapter_model.safetensors`` and\n            ``adapter_config.json``.\n\n    Returns:\n        Tuple of ``(weights_dict, config_dict)``.\n\n    Raises:\n        FileNotFoundError: If adapter files are missing.\n    \"\"\"\n    adapter_dir = adapter_dir.expanduser().resolve()\n    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    safetensors_path = adapter_dir / \"adapter_model.safetensors\"\n    if not safetensors_path.exists():\n        raise FileNotFoundError(f\"Adapter weights not found: {safetensors_path}\")\n\n    config_path = adapter_dir / \"adapter_config.json\"\n    if not config_path.exists():\n        raise FileNotFoundError(f\"Adapter config not found: {config_path}\")\n\n    weights = load_file(str(safetensors_path), device=device)\n    with open(config_path) as f:\n        config = json.load(f)\n    return weights, config\n\n\n# ---------------------------------------------------------------------------\n# Non-weight file copying\n# ---------------------------------------------------------------------------\n\n\ndef copy_model_code_files(model_dir: Path, output_path: Path) -> None:\n    \"\"\"Copy custom model code files (``*.py``) to the output directory.\n\n    Some model architectures require ``trust_remote_code=True`` and ship\n    custom Python files (e.g. ``configuration_*.py``, ``modeling_*.py``).\n    This copies those files so the merged model can be loaded standalone.\n\n    Only copies ``*.py`` files. Config and tokenizer files are handled\n    separately via HF APIs (``AutoConfig.save_pretrained``, etc.) to\n    avoid accidentally copying stale index files or other artifacts that\n    could break downstream loaders like vLLM/SGLang.\n\n    Args:\n        model_dir: Source model directory.\n        output_path: Destination directory (must exist).\n    \"\"\"\n    for pattern in _MODEL_CODE_PATTERNS:\n        for item in sorted(model_dir.glob(pattern)):\n            dest = output_path / item.name\n            if not dest.exists():\n                copy_artifact_file(item, dest)\n                logger.debug(\"Copied %s\", item.name)\n\n\n# ---------------------------------------------------------------------------\n# ShardWriter — accumulate and write safetensors shards\n# ---------------------------------------------------------------------------\n\n\nclass ShardWriter:\n    \"\"\"Accumulates tensors and writes numbered safetensors shard files.\n\n    Writes to temporary files during processing, then renames to final names\n    in :meth:`finalize`. 
This ensures partial failures don't leave behind\n    confusingly-named output files.\n\n    Args:\n        output_path: Directory where shard files will be written. Must exist.\n        max_shard_size: Maximum size (bytes) per output shard. Default 10 GB.\n    \"\"\"\n\n    def __init__(self, output_path: Path, max_shard_size: int = _MAX_SHARD_SIZE):\n        self._output_path = output_path\n        self._max_shard_size = max_shard_size\n        self._pending: dict[str, torch.Tensor] = {}\n        self._pending_size: int = 0\n        self._shard_count: int = 0\n        self._shard_keys: list[list[str]] = []\n        self._total_size: int = 0\n\n    def add_tensor(self, key: str, tensor: torch.Tensor) -> None:\n        \"\"\"Add a tensor to the current shard buffer.\n\n        Automatically flushes the buffer when adding this tensor would exceed\n        ``max_shard_size``.\n        \"\"\"\n        size = tensor.nelement() * tensor.element_size()\n        if self._pending and self._pending_size + size > self._max_shard_size:\n            self.flush()\n        self._pending[key] = tensor\n        self._pending_size += size\n        self._total_size += size\n\n    def flush(self) -> None:\n        \"\"\"Write buffered tensors to a temporary shard file.\"\"\"\n        if not self._pending:\n            return\n        # Use next shard number for temp file name, but only commit the\n        # count increment after save_file succeeds — avoids inconsistent\n        # state if the write fails (e.g. disk full).\n        next_idx = self._shard_count + 1\n        temp_name = f\"shard-{next_idx:05d}.tmp.safetensors\"\n        save_file(self._pending, str(self._output_path / temp_name))\n        # Write succeeded — commit state updates\n        self._shard_count = next_idx\n        self._shard_keys.append(list(self._pending.keys()))\n        logger.debug(\"Flushed %d tensors to %s\", len(self._pending), temp_name)\n        self._pending = {}\n        self._pending_size = 0\n\n    def finalize(self) -> dict[str, str]:\n        \"\"\"Flush remaining tensors, rename temps to final names, return weight map.\n\n        Returns:\n            Dict mapping tensor key to shard filename, suitable for\n            ``model.safetensors.index.json``.\n        \"\"\"\n        self.flush()\n        total = self._shard_count\n        weight_map: dict[str, str] = {}\n\n        for i in range(total):\n            temp_name = f\"shard-{i + 1:05d}.tmp.safetensors\"\n            if total == 1:\n                final_name = \"model.safetensors\"\n            else:\n                final_name = f\"model-{i + 1:05d}-of-{total:05d}.safetensors\"\n            (self._output_path / temp_name).rename(self._output_path / final_name)\n            for key in self._shard_keys[i]:\n                weight_map[key] = final_name\n\n        logger.info(\n            \"Wrote %d output shard(s), total %.1f GB\",\n            total,\n            self._total_size / (1024**3),\n        )\n        return weight_map\n\n    @property\n    def total_size(self) -> int:\n        \"\"\"Total bytes of tensors written (including pending).\"\"\"\n        return self._total_size\n"
  },
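A short sketch of the header-only inspection helpers and the `ShardWriter` flush/finalize cycle described above. Directories, tensor names, and shapes are illustrative.

```python
from pathlib import Path

import torch

from tinker_cookbook.weights._artifacts import (
    ShardWriter,
    get_model_state_shapes,
    get_shard_files,
)

model_dir = Path("./model")  # a local HF model directory containing .safetensors shards
output_dir = Path("./resharded")
output_dir.mkdir(parents=True, exist_ok=True)

# Header-only reads: inspect keys/shapes and the shard layout without loading tensor data.
shapes = get_model_state_shapes(model_dir)
print(len(shapes), "tensors across shards:", get_shard_files(model_dir))

# Accumulate tensors: add_tensor() flushes automatically when a shard would exceed
# max_shard_size, and finalize() renames temp shards and returns the weight map
# for model.safetensors.index.json.
writer = ShardWriter(output_dir, max_shard_size=1024**3)
writer.add_tensor("layers.0.weight", torch.zeros(4096, 4096, dtype=torch.bfloat16))
writer.add_tensor("layers.1.weight", torch.zeros(4096, 4096, dtype=torch.bfloat16))
weight_map = writer.finalize()  # single output shard here -> {"layers.0.weight": "model.safetensors", ...}
```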
  {
    "path": "tinker_cookbook/weights/_download.py",
    "content": "\"\"\"Download checkpoint weights from Tinker storage.\"\"\"\n\nfrom __future__ import annotations\n\nimport tarfile\nimport tempfile\nimport urllib.error\nimport urllib.request\nfrom pathlib import Path\n\nimport tinker\n\nfrom tinker_cookbook.exceptions import WeightsDownloadError\n\n\ndef download(*, tinker_path: str, output_dir: str, base_url: str | None = None) -> str:\n    \"\"\"Download a checkpoint from Tinker storage to local disk.\n\n    Fetches a signed URL via the Tinker SDK, downloads the archive, and\n    extracts it with security validation (rejects symlinks and path\n    traversal).\n\n    Args:\n        tinker_path: Tinker checkpoint path, e.g.\n            ``\"tinker://<run_id>/sampler_weights/final\"``.\n        output_dir: Local directory where the checkpoint will be extracted.\n        base_url: Custom Tinker service URL. If ``None`` (default), uses\n            the default Tinker service endpoint (or ``TINKER_BASE_URL``\n            environment variable if set).\n\n    Returns:\n        Path to the extracted checkpoint directory.\n\n    Raises:\n        WeightsDownloadError: If the archive contains unsafe entries.\n        urllib.error.URLError: If the download fails.\n\n    Example::\n\n        from tinker_cookbook import weights\n\n        # Download from default Tinker service\n        adapter_dir = weights.download(\n            tinker_path=\"tinker://run-id/sampler_weights/final\",\n            output_dir=\"./adapter\",\n        )\n\n        # Download from a custom Tinker deployment\n        adapter_dir = weights.download(\n            tinker_path=\"tinker://run-id/sampler_weights/final\",\n            output_dir=\"./adapter\",\n            base_url=\"https://tinker.my-company.com\",\n        )\n    \"\"\"\n    kwargs: dict = {}\n    if base_url is not None:\n        kwargs[\"base_url\"] = base_url\n    try:\n        sc = tinker.ServiceClient(**kwargs)\n        rc = sc.create_rest_client()\n    except Exception as e:\n        raise WeightsDownloadError(\n            \"Failed to connect to Tinker service. \"\n            \"Ensure TINKER_API_KEY is set and the service is reachable.\"\n        ) from e\n\n    try:\n        response = rc.get_checkpoint_archive_url_from_tinker_path(tinker_path).result()\n    except Exception as e:\n        raise WeightsDownloadError(\n            f\"Failed to get download URL for {tinker_path!r}. \"\n            f\"Check that the checkpoint path is valid and the checkpoint has not expired.\"\n        ) from e\n\n    out = Path(output_dir)\n    out.mkdir(parents=True, exist_ok=True)\n\n    with tempfile.NamedTemporaryFile(suffix=\".tar\", delete=False) as tmp:\n        tmp_path = Path(tmp.name)\n    try:\n        try:\n            urllib.request.urlretrieve(response.url, str(tmp_path))\n        except urllib.error.URLError as e:\n            raise WeightsDownloadError(\n                \"Failed to download checkpoint archive from signed URL. 
\"\n                \"The URL may have expired — try downloading again.\"\n            ) from e\n        _safe_extract_tar(tmp_path, out)\n    finally:\n        tmp_path.unlink(missing_ok=True)\n\n    return output_dir\n\n\ndef _safe_extract_tar(archive_path: Path, extract_dir: Path) -> None:\n    \"\"\"Extract a tar archive with security validation.\n\n    Rejects archives containing symlinks, hardlinks, or paths that escape\n    the extraction directory (path traversal).\n    \"\"\"\n    base = extract_dir.resolve()\n    with tarfile.open(archive_path, \"r\") as tar:\n        members = tar.getmembers()\n        for member in members:\n            if member.issym() or member.islnk():\n                raise WeightsDownloadError(\n                    \"Unsafe symlink or hardlink found in tar archive. \"\n                    \"Archive may be corrupted or malicious.\"\n                )\n            member_path = (extract_dir / member.name).resolve()\n            if not member_path.is_relative_to(base):\n                raise WeightsDownloadError(\n                    \"Unsafe path found in tar archive (path traversal). \"\n                    \"Archive may be corrupted or malicious.\"\n                )\n        tar.extractall(path=extract_dir)\n"
  },
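A minimal error-handling sketch around `download`: `WeightsDownloadError` wraps the failure modes the module raises (connection problems, invalid or expired checkpoint paths, expired signed URLs, and archives rejected by the safe-extraction check). The checkpoint path is the illustrative one from the docstring.

```python
from tinker_cookbook import weights
from tinker_cookbook.exceptions import WeightsDownloadError

try:
    adapter_dir = weights.download(
        tinker_path="tinker://run-id/sampler_weights/final",
        output_dir="./adapter",
    )
except WeightsDownloadError as e:
    # Covers connection failures, bad or expired checkpoint paths, expired signed
    # URLs, and archives containing symlinks/hardlinks or path traversal.
    print(f"Checkpoint download failed: {e}")
    raise
else:
    print(f"Adapter extracted to {adapter_dir}")
```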
  {
    "path": "tinker_cookbook/weights/_export/__init__.py",
    "content": "\"\"\"Build deployable model artifacts from Tinker weights.\n\nProvides :func:`build_hf_model`, the main entry point for merging a Tinker\nLoRA adapter into a HuggingFace model. Supports multiple merge strategies:\n\n- ``\"full\"`` — loads the entire base model into memory (original behavior)\n- ``\"shard\"`` — processes one safetensors shard at a time (low memory)\n- ``\"auto\"`` (default) — uses shard-by-shard\n\nModel-specific export strategies (e.g. DeepSeek FP8) live in their own\nsubmodules and are dispatched automatically based on ``config.json``.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nimport os\nimport shutil\nfrom pathlib import Path\n\nimport torch\nfrom transformers import (\n    AutoConfig,\n    AutoProcessor,\n    AutoTokenizer,\n    PretrainedConfig,\n)\n\nfrom tinker_cookbook.exceptions import ConfigurationError\n\nlogger = logging.getLogger(__name__)\n\n# Map user-facing dtype strings to torch dtypes.\n_DTYPE_MAP: dict[str, torch.dtype] = {\n    \"bfloat16\": torch.bfloat16,\n    \"float16\": torch.float16,\n    \"float32\": torch.float32,\n}\n\n_VALID_STRATEGIES = {\"auto\", \"shard\", \"full\"}\n_VALID_QUANTIZE = {\"experts-fp8\"}\n_VALID_SERVING_FORMATS = {\"vllm\"}\n\n\ndef build_hf_model(\n    *,\n    base_model: str,\n    adapter_path: str,\n    output_path: str,\n    dtype: str = \"bfloat16\",\n    trust_remote_code: bool | None = None,\n    merge_strategy: str = \"auto\",\n    dequantize: bool = False,\n    quantize: str | None = None,\n    serving_format: str | None = None,\n) -> None:\n    \"\"\"Build a complete HuggingFace model from Tinker LoRA adapter weights.\n\n    Merges the LoRA adapter into the base model and saves the result as a\n    standard HuggingFace model directory, compatible with vLLM, SGLang, TGI,\n    or any HuggingFace-compatible inference framework.\n\n    Args:\n        base_model: HuggingFace model name (e.g. ``\"Qwen/Qwen3.5-35B-A3B\"``)\n            or local path to a saved HuggingFace model.\n        adapter_path: Local path to the Tinker adapter directory. Must contain\n            ``adapter_model.safetensors`` and ``adapter_config.json``.\n        output_path: Directory where the merged model will be saved. Must not\n            already exist.\n        dtype: Data type for loading the base model. One of ``\"bfloat16\"``\n            (default), ``\"float16\"``, or ``\"float32\"``. Use ``\"float32\"``\n            for maximum precision during merge. Only used by\n            ``merge_strategy=\"full\"``; the shard strategy preserves the\n            on-disk dtype of each tensor.\n        trust_remote_code: Whether to trust remote code when loading HF\n            models. Required for some newer model architectures (e.g.\n            Qwen3.5). If ``None`` (default), falls back to the\n            ``HF_TRUST_REMOTE_CODE`` environment variable, then ``False``.\n        merge_strategy: Controls how the merge is performed. ``\"auto\"``\n            (default) uses shard-by-shard processing for lower peak memory.\n            ``\"shard\"`` forces shard-by-shard (fails if shards can't be\n            resolved). ``\"full\"`` forces full-model loading (original\n            behavior, higher memory but simpler).\n        dequantize: If ``True``, dequantize quantized base model weights\n            before merging. Not yet implemented for the standard merge path,\n            but used internally by the quantized export path for models with\n            native FP8 weights (e.g. 
DeepSeek V3.1).\n        quantize: Output quantization method. Currently supported:\n            ``\"experts-fp8\"`` — quantize routed expert weights to FP8 with\n            blockwise scaling. Requires ``serving_format`` to be set.\n            ``None`` (default) — no quantization.\n        serving_format: Serving framework format for quantization metadata.\n            Currently supported: ``\"vllm\"`` — write compressed-tensors\n            config for vLLM. Required when ``quantize`` is set.\n            ``None`` (default) — no serving-specific metadata.\n\n    Raises:\n        FileNotFoundError: If adapter files are missing.\n        FileExistsError: If output_path already exists.\n        KeyError: If adapter config is malformed.\n        ValueError: If tensor shapes are incompatible during merge, or\n            if ``dtype``, ``merge_strategy``, ``quantize``, or\n            ``serving_format`` is not a recognized value, or if\n            ``quantize`` and ``serving_format`` are not both set/unset.\n        NotImplementedError: If ``dequantize=True`` on the standard merge path.\n    \"\"\"\n    # --- Validate quantize / serving_format ---\n    if quantize is not None and quantize not in _VALID_QUANTIZE:\n        raise ConfigurationError(\n            f\"Unsupported quantize={quantize!r}. Choose from: {sorted(_VALID_QUANTIZE)}\"\n        )\n    if serving_format is not None and serving_format not in _VALID_SERVING_FORMATS:\n        raise ConfigurationError(\n            f\"Unsupported serving_format={serving_format!r}. \"\n            f\"Choose from: {sorted(_VALID_SERVING_FORMATS)}\"\n        )\n    if quantize is not None and serving_format is None:\n        raise ConfigurationError(\n            f\"quantize={quantize!r} requires serving_format to be set \"\n            f\"(e.g. serving_format='vllm') to write scale metadata.\"\n        )\n    if serving_format is not None and quantize is None:\n        raise ConfigurationError(\n            f\"serving_format={serving_format!r} requires quantize to be set \"\n            f\"(e.g. quantize='experts-fp8'). Serving format without quantization is meaningless.\"\n        )\n    if quantize == \"experts-fp8\" and dtype != \"bfloat16\":\n        raise ConfigurationError(\n            f\"quantize='experts-fp8' requires dtype='bfloat16', got dtype={dtype!r}.\"\n        )\n\n    # --- Validate standard params ---\n    if dequantize and quantize is None:\n        raise NotImplementedError(\n            \"dequantize is not yet supported for the standard merge path. \"\n            \"Use quantize='experts-fp8' for models with native FP8 weights.\"\n        )\n    if dtype not in _DTYPE_MAP:\n        raise ConfigurationError(\n            f\"Unsupported dtype {dtype!r}. Choose from: {list(_DTYPE_MAP.keys())}\"\n        )\n    if merge_strategy not in _VALID_STRATEGIES:\n        raise ConfigurationError(\n            f\"Unsupported merge_strategy {merge_strategy!r}. \"\n            f\"Choose from: {sorted(_VALID_STRATEGIES)}\"\n        )\n\n    resolved_trust = resolve_trust_remote_code(trust_remote_code)\n\n    # Load model config for model-family detection (lightweight, no weight download).\n    config_dict = load_config_dict(base_model)\n\n    # --- Warn if native FP8 model without quantized export ---\n    if quantize is None and _has_native_fp8(config_dict):\n        logger.warning(\n            \"This model appears to have native FP8 weights \"\n            \"(quantization_config.quant_method='fp8'). 
\"\n            \"The standard merge path will apply LoRA deltas directly to FP8 tensors, \"\n            \"which may produce incorrect results due to FP8 precision loss. \"\n            \"Consider using quantize='experts-fp8' and serving_format='vllm' \"\n            \"for correct FP8-aware merging.\"\n        )\n\n    # --- Quantized export path ---\n    if quantize is not None:\n        from tinker_cookbook.weights._artifacts import resolve_model_dir\n        from tinker_cookbook.weights._export._quantized import build_quantized\n\n        model_dir = resolve_model_dir(base_model)\n        build_quantized(\n            base_model=base_model,\n            adapter_path=adapter_path,\n            output_path=output_path,\n            trust_remote_code=resolved_trust,\n            model_dir=model_dir,\n            config_dict=config_dict,\n            serving_format=serving_format,  # type: ignore[arg-type]  # validated non-None above\n        )\n        return\n\n    # --- Standard merge path ---\n    strategy = _resolve_strategy(merge_strategy)\n\n    if strategy == \"full\":\n        from tinker_cookbook.weights._export._full import build_full\n\n        build_full(\n            base_model=base_model,\n            adapter_path=adapter_path,\n            output_path=output_path,\n            dtype=dtype,\n            torch_dtype=_DTYPE_MAP[dtype],\n            trust_remote_code=resolved_trust,\n            config_dict=config_dict,\n        )\n    else:\n        if dtype != \"bfloat16\":\n            logger.warning(\n                \"dtype=%r only applies to merge_strategy='full'. \"\n                \"The shard strategy preserves each tensor's on-disk dtype. \"\n                \"Pass merge_strategy='full' to control output precision.\",\n                dtype,\n            )\n\n        from tinker_cookbook.weights._artifacts import resolve_model_dir\n        from tinker_cookbook.weights._export._shard import build_sharded\n\n        model_dir = resolve_model_dir(base_model)\n        build_sharded(\n            base_model=base_model,\n            adapter_path=adapter_path,\n            output_path=output_path,\n            trust_remote_code=resolved_trust,\n            model_dir=model_dir,\n            config_dict=config_dict,\n        )\n\n\ndef _has_native_fp8(config_dict: dict) -> bool:\n    \"\"\"Check if a model config indicates native FP8 quantization.\"\"\"\n    quant_config = config_dict.get(\"quantization_config\")\n    if not isinstance(quant_config, dict):\n        return False\n    return quant_config.get(\"quant_method\", \"\") == \"fp8\"\n\n\ndef _resolve_strategy(merge_strategy: str) -> str:\n    \"\"\"Resolve ``\"auto\"`` to a concrete strategy.\"\"\"\n    if merge_strategy == \"auto\":\n        return \"shard\"\n    return merge_strategy\n\n\n# ---------------------------------------------------------------------------\n# Shared helpers used by export strategy modules\n# ---------------------------------------------------------------------------\n\n\ndef resolve_trust_remote_code(trust_remote_code: bool | None) -> bool:\n    \"\"\"Resolve trust_remote_code from parameter or environment variable.\n\n    Priority: explicit parameter > HF_TRUST_REMOTE_CODE env var > False.\n    \"\"\"\n    if trust_remote_code is not None:\n        return trust_remote_code\n    env_val = os.environ.get(\"HF_TRUST_REMOTE_CODE\", \"\").lower()\n    return env_val in (\"1\", \"true\", \"yes\")\n\n\ndef load_config_dict(model_dir_or_name: str | Path) -> dict:\n    \"\"\"Load config.json as a 
raw dict from a local directory or HF model name.\n\n    For local directories, reads config.json directly. For HF model names\n    (not a local directory), falls back to ``AutoConfig.from_pretrained``.\n\n    Raises:\n        FileNotFoundError: If ``model_dir_or_name`` is a local directory\n            that doesn't contain ``config.json``.\n    \"\"\"\n    model_dir = (\n        Path(model_dir_or_name) if not isinstance(model_dir_or_name, Path) else model_dir_or_name\n    )\n    config_path = model_dir / \"config.json\"\n    if config_path.exists():\n        with open(config_path) as f:\n            return json.load(f)\n    # If it's a local directory without config.json, fail explicitly\n    if model_dir.is_dir():\n        raise FileNotFoundError(\n            f\"No config.json found in {model_dir}. \"\n            f\"Ensure this is a valid HuggingFace model directory.\"\n        )\n    # Fall back to HF config loading for remote model names\n    config = AutoConfig.from_pretrained(str(model_dir_or_name))\n    return config.to_dict()\n\n\ndef is_multimodal(config: PretrainedConfig) -> bool:\n    \"\"\"Check if a model config indicates a multimodal (e.g. vision-language) model.\"\"\"\n    multimodal_config_keys = (\"vision_config\", \"audio_config\", \"speech_config\")\n    return any(\n        hasattr(config, key) and getattr(config, key) is not None for key in multimodal_config_keys\n    )\n\n\ndef is_multimodal_from_dict(config_dict: dict) -> bool:\n    \"\"\"Check if a raw config dict indicates a multimodal model.\"\"\"\n    multimodal_keys = (\"vision_config\", \"audio_config\", \"speech_config\")\n    return any(config_dict.get(key) is not None for key in multimodal_keys)\n\n\ndef save_tokenizer_and_processor(\n    base_model: str,\n    output_path: Path,\n    multimodal: bool,\n    trust_remote_code: bool,\n) -> None:\n    \"\"\"Save tokenizer and optional processor to the output directory.\"\"\"\n    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code)\n    tokenizer.save_pretrained(output_path)\n\n    if multimodal:\n        try:\n            processor = AutoProcessor.from_pretrained(\n                base_model, trust_remote_code=trust_remote_code\n            )\n            processor.save_pretrained(output_path)\n        except (OSError, ValueError) as e:\n            logger.warning(\n                \"Could not load processor for vision model %s: %s. \"\n                \"You may need to copy the processor files manually.\",\n                base_model,\n                e,\n            )\n\n\ndef cleanup_on_failure(out: Path) -> None:\n    \"\"\"Clean up partial output so the user can retry without manual deletion.\"\"\"\n    try:\n        if out.exists():\n            shutil.rmtree(out)\n    except OSError:\n        logger.warning(\"Failed to clean up partial output at %s\", out)\n"
  },
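A sketch of the three ways the validation above allows `build_hf_model` to be called. Model names and paths are illustrative; the DeepSeek repo id in particular is an assumption, not something the module prescribes.

```python
from tinker_cookbook import weights

# Default: "auto" resolves to shard-by-shard merging; each tensor keeps its on-disk dtype.
weights.build_hf_model(
    base_model="Qwen/Qwen3-8B",
    adapter_path="./adapter",
    output_path="./merged_default",
)

# Full-model merge when you want explicit precision control (higher peak memory).
weights.build_hf_model(
    base_model="Qwen/Qwen3-8B",
    adapter_path="./adapter",
    output_path="./merged_fp32",
    merge_strategy="full",
    dtype="float32",
)

# FP8 routed-expert quantization: quantize and serving_format must be set together,
# and dtype must stay at the default "bfloat16".
weights.build_hf_model(
    base_model="deepseek-ai/DeepSeek-V3.1",  # illustrative repo id
    adapter_path="./adapter",
    output_path="./merged_fp8",
    quantize="experts-fp8",
    serving_format="vllm",
)
```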
  {
    "path": "tinker_cookbook/weights/_export/_full.py",
    "content": "\"\"\"Full-model export strategy.\n\nLoads the entire base model into memory, merges LoRA adapter weights in-place,\nand saves via ``model.save_pretrained()``. This is the original merge behavior\nand serves as the fallback when shard-by-shard processing isn't suitable.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom pathlib import Path\n\nimport torch\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoModelForImageTextToText,\n    PretrainedConfig,\n    PreTrainedModel,\n)\n\nfrom tinker_cookbook.weights._artifacts import load_adapter_weights\nfrom tinker_cookbook.weights._export import (\n    cleanup_on_failure,\n    is_multimodal,\n    is_multimodal_from_dict,\n    save_tokenizer_and_processor,\n)\nfrom tinker_cookbook.weights._merge import merge_adapter_weights\n\nlogger = logging.getLogger(__name__)\n\n\ndef build_full(\n    *,\n    base_model: str,\n    adapter_path: str,\n    output_path: str,\n    dtype: str,\n    torch_dtype: torch.dtype,\n    trust_remote_code: bool,\n    config_dict: dict,\n) -> None:\n    \"\"\"Merge by loading the entire base model into memory.\n\n    Args:\n        base_model: HuggingFace model name or local path.\n        adapter_path: Path to adapter directory.\n        output_path: Where to write the merged model.\n        dtype: String dtype name (for logging).\n        torch_dtype: Torch dtype for model loading.\n        trust_remote_code: Whether to trust remote code for HF loading.\n        config_dict: Parsed config.json dict (loaded by dispatcher).\n    \"\"\"\n    # Fail fast if output already exists (before any expensive work)\n    out = Path(output_path)\n    if out.exists():\n        raise FileExistsError(f\"Output path already exists: {out}\")\n\n    # Validate adapter exists before loading the (potentially huge) base model\n    adapter_weights, adapter_config = load_adapter_weights(Path(adapter_path))\n\n    out.mkdir(parents=True, exist_ok=False)\n\n    try:\n        logger.info(\"Loading base model: %s (dtype=%s)\", base_model, dtype)\n        config = AutoConfig.from_pretrained(base_model, trust_remote_code=trust_remote_code)\n        hf_model = _load_model(\n            config, base_model, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code\n        )\n\n        logger.info(\"Merging adapter weights\")\n        merge_adapter_weights(hf_model, adapter_weights, adapter_config)\n\n        logger.info(\"Saving merged model to: %s\", out)\n        hf_model.save_pretrained(out)\n\n        save_tokenizer_and_processor(\n            base_model, out, is_multimodal_from_dict(config_dict), trust_remote_code\n        )\n\n        logger.info(\"Done — merged model saved to %s\", out)\n    except Exception:\n        cleanup_on_failure(out)\n        raise\n\n\ndef _load_model(\n    config: PretrainedConfig,\n    model_path: str,\n    *,\n    torch_dtype: torch.dtype,\n    trust_remote_code: bool,\n) -> PreTrainedModel:\n    auto_cls = AutoModelForImageTextToText if is_multimodal(config) else AutoModelForCausalLM\n    return auto_cls.from_pretrained(\n        model_path, dtype=torch_dtype, trust_remote_code=trust_remote_code\n    )\n"
  },
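`build_full` is normally reached through `build_hf_model(..., merge_strategy="full")`; the direct call below just mirrors what that dispatcher passes in, as a sketch. Model name and paths are illustrative.

```python
import torch

from tinker_cookbook.weights._export import load_config_dict, resolve_trust_remote_code
from tinker_cookbook.weights._export._full import build_full

base_model = "Qwen/Qwen3-8B"  # illustrative

build_full(
    base_model=base_model,
    adapter_path="./adapter",
    output_path="./merged_model",  # must not already exist
    dtype="bfloat16",              # string form, used for logging
    torch_dtype=torch.bfloat16,    # actual dtype passed to from_pretrained
    trust_remote_code=resolve_trust_remote_code(None),  # HF_TRUST_REMOTE_CODE env var, else False
    config_dict=load_config_dict(base_model),  # used here only for multimodal detection
)
```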
  {
    "path": "tinker_cookbook/weights/_export/_quantized.py",
    "content": "\"\"\"Quantized export strategy.\n\nMerges LoRA adapters shard-by-shard and quantizes routed expert weights to FP8.\nProduces output compatible with vLLM's compressed-tensors format.\n\nCurrently supports DeepSeek V3/V3.1 models. The infrastructure (FP8 math, vLLM\nconfig generation, resume support) is reusable for future model families.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nfrom collections.abc import Callable\nfrom pathlib import Path\n\nimport torch\nfrom safetensors import safe_open\nfrom safetensors.torch import load_file, save_file\n\nfrom tinker_cookbook.exceptions import WeightsMergeError\nfrom tinker_cookbook.weights._artifacts import (\n    copy_artifact_file,\n    copy_model_code_files,\n    get_model_state_shapes,\n    get_shard_files,\n    load_adapter_weights,\n)\nfrom tinker_cookbook.weights._export import (\n    is_multimodal_from_dict,\n    save_tokenizer_and_processor,\n)\nfrom tinker_cookbook.weights._merge import (\n    apply_merge_op,\n    detect_merge_profile,\n    plan_merge_ops,\n    validate_merge_op_shapes,\n)\n\nlogger = logging.getLogger(__name__)\n\n# ---------------------------------------------------------------------------\n# DeepSeek detection\n# ---------------------------------------------------------------------------\n\n_DEEPSEEK_MODEL_TYPES = frozenset({\"deepseek_v3\"})\n\n\ndef is_deepseek_config(config_dict: dict) -> bool:\n    \"\"\"Check if config describes a DeepSeek model family.\"\"\"\n    return config_dict.get(\"model_type\") in _DEEPSEEK_MODEL_TYPES\n\n\n# ---------------------------------------------------------------------------\n# FP8 blockwise quantization\n# ---------------------------------------------------------------------------\n\n# DeepSeek V3/V3.1 native FP8 block size\n_FP8_BLOCK_SIZE = 128\n\n\ndef _get_fp8_max() -> float:\n    \"\"\"Get max representable value in float8_e4m3fn, with fallback for older PyTorch.\"\"\"\n    try:\n        return float(torch.finfo(torch.float8_e4m3fn).max)\n    except TypeError:\n        return 448.0\n\n\n_FP8_MAX = _get_fp8_max()\n\n\ndef quantize_blockwise(\n    tensor: torch.Tensor,\n    block_size: tuple[int, int] = (_FP8_BLOCK_SIZE, _FP8_BLOCK_SIZE),\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Quantize a 2D tensor to FP8 using blockwise scaling.\n\n    Divides the tensor into blocks, computes a per-block scale factor, and\n    quantizes each block to float8_e4m3fn.\n\n    Args:\n        tensor: 2D float tensor to quantize.\n        block_size: (row_block, col_block) sizes. 
Tensor is padded if dimensions\n            are not evenly divisible.\n\n    Returns:\n        Tuple of (quantized_fp8, scale_inv) where:\n        - quantized_fp8: float8_e4m3fn tensor, same shape as input\n        - scale_inv: float32 tensor of shape (ceil(rows/row_block), ceil(cols/col_block))\n    \"\"\"\n    assert tensor.ndim == 2, f\"Expected 2D tensor, got {tensor.ndim}D\"\n    rows, cols = tensor.shape\n    rb, cb = block_size\n\n    # Pad to block boundaries\n    pad_rows = (rb - rows % rb) % rb\n    pad_cols = (cb - cols % cb) % cb\n    if pad_rows > 0 or pad_cols > 0:\n        padded = torch.zeros(\n            rows + pad_rows, cols + pad_cols, dtype=tensor.dtype, device=tensor.device\n        )\n        padded[:rows, :cols] = tensor\n    else:\n        padded = tensor\n\n    # Reshape into blocks\n    pr, pc = padded.shape\n    blocks = padded.reshape(pr // rb, rb, pc // cb, cb).permute(0, 2, 1, 3)\n\n    # Per-block max for scale computation\n    block_max = blocks.abs().reshape(blocks.shape[0], blocks.shape[1], -1).max(dim=-1).values\n    # Avoid division by zero\n    block_max = block_max.clamp(min=1e-12)\n\n    scale = block_max / _FP8_MAX\n    scale_inv = scale  # scale_inv[i,j] = max_val / FP8_MAX\n\n    # Quantize: scale each block, clamp, cast\n    inv_scale = 1.0 / scale.unsqueeze(-1).unsqueeze(-1)  # broadcast over block dims\n    scaled_blocks = blocks.float() * inv_scale\n    clamped = scaled_blocks.clamp(-_FP8_MAX, _FP8_MAX)\n\n    # Reshape back to padded shape\n    quantized_padded = clamped.permute(0, 2, 1, 3).reshape(pr, pc)\n\n    # Trim padding\n    quantized = quantized_padded[:rows, :cols].to(torch.float8_e4m3fn)\n\n    return quantized, scale_inv.float()\n\n\ndef dequantize_blockwise(\n    quantized: torch.Tensor,\n    scale_inv: torch.Tensor,\n    block_size: tuple[int, int] = (_FP8_BLOCK_SIZE, _FP8_BLOCK_SIZE),\n    dtype: torch.dtype = torch.bfloat16,\n) -> torch.Tensor:\n    \"\"\"Dequantize an FP8 tensor back to float using blockwise scales.\n\n    Args:\n        quantized: float8_e4m3fn tensor.\n        scale_inv: float32 scale tensor from :func:`quantize_blockwise`.\n        block_size: Must match the block_size used during quantization.\n        dtype: Output dtype.\n\n    Returns:\n        Dequantized tensor in the requested dtype.\n    \"\"\"\n    assert quantized.ndim == 2, f\"Expected 2D tensor, got {quantized.ndim}D\"\n    rows, cols = quantized.shape\n    rb, cb = block_size\n\n    # Pad to block boundaries\n    pad_rows = (rb - rows % rb) % rb\n    pad_cols = (cb - cols % cb) % cb\n    if pad_rows > 0 or pad_cols > 0:\n        padded = torch.zeros(\n            rows + pad_rows, cols + pad_cols, dtype=torch.float32, device=quantized.device\n        )\n        padded[:rows, :cols] = quantized.float()\n    else:\n        padded = quantized.float()\n\n    # Reshape into blocks\n    pr, pc = padded.shape\n    blocks = padded.reshape(pr // rb, rb, pc // cb, cb).permute(0, 2, 1, 3)\n\n    # Multiply by scale\n    blocks = blocks * scale_inv.unsqueeze(-1).unsqueeze(-1)\n\n    # Reshape back\n    result = blocks.permute(0, 2, 1, 3).reshape(pr, pc)\n    return result[:rows, :cols].to(dtype)\n\n\n# ---------------------------------------------------------------------------\n# Weight classification\n# ---------------------------------------------------------------------------\n\n# Pattern for routed expert weights in DeepSeek models\n# e.g. 
\"model.layers.3.mlp.experts.42.gate_proj.weight\"\n_ROUTED_EXPERT_PATTERN = \".mlp.experts.\"\n_SHARED_EXPERT_PATTERN = \".mlp.shared_experts.\"\n\n\ndef _is_routed_expert_weight(key: str) -> bool:\n    \"\"\"Check if a weight key belongs to a routed (non-shared) expert.\"\"\"\n    return _ROUTED_EXPERT_PATTERN in key and _SHARED_EXPERT_PATTERN not in key\n\n\n# ---------------------------------------------------------------------------\n# Keys to skip in DeepSeek checkpoints\n# ---------------------------------------------------------------------------\n\n# DeepSeek has some keys that should not be part of the merge:\n# - Layer 61 is a placeholder/unused layer in some checkpoints\n# - rotary_emb inverse frequency is derived, not a trained parameter\n_SKIP_SUFFIXES = (\".rotary_emb.inv_freq\",)\n_SKIP_LAYER_INDICES = frozenset({61})\n\n\ndef _should_skip_checkpoint_key(key: str) -> bool:\n    \"\"\"Check if a checkpoint key should be excluded from merge planning.\"\"\"\n    if any(key.endswith(s) for s in _SKIP_SUFFIXES):\n        return True\n    # Check for layer 61 (DeepSeek-specific)\n    parts = key.split(\".\")\n    for i, part in enumerate(parts):\n        if part == \"layers\" and i + 1 < len(parts):\n            try:\n                layer_idx = int(parts[i + 1])\n                if layer_idx in _SKIP_LAYER_INDICES:\n                    return True\n            except ValueError:\n                pass\n    return False\n\n\n# ---------------------------------------------------------------------------\n# Native FP8 checkpoint handling\n# ---------------------------------------------------------------------------\n\n\ndef _has_native_fp8_quantization(config_dict: dict) -> bool:\n    \"\"\"Check if the model checkpoint uses native FP8 quantization.\n\n    DeepSeek V3.1 checkpoints can ship with native FP8 weights and\n    ``quantization_config.quant_method == \"fp8\"``. These need to be\n    dequantized before re-quantizing with our own scales.\n    \"\"\"\n    quant_config = config_dict.get(\"quantization_config\")\n    if quant_config is None:\n        return False\n    if isinstance(quant_config, dict):\n        return quant_config.get(\"quant_method\", \"\") == \"fp8\"\n    return False\n\n\ndef _get_native_block_size(config_dict: dict) -> tuple[int, int]:\n    \"\"\"Get the FP8 block size from the model's native quantization config.\n\n    Falls back to the standard DeepSeek block size (128, 128) if not specified.\n    \"\"\"\n    quant_config = config_dict.get(\"quantization_config\", {})\n    if isinstance(quant_config, dict):\n        block_size = quant_config.get(\"weight_block_size\", [_FP8_BLOCK_SIZE, _FP8_BLOCK_SIZE])\n        return (int(block_size[0]), int(block_size[1]))\n    return (_FP8_BLOCK_SIZE, _FP8_BLOCK_SIZE)\n\n\ndef _make_cross_shard_tensor_loader(\n    model_dir: Path,\n) -> Callable[[str], torch.Tensor]:\n    \"\"\"Create a loader that can fetch a single tensor from any shard by key name.\n\n    Used when a weight tensor and its scale are in different shards. 
Reads\n    the safetensors index to find which shard contains a given key, then\n    uses ``safe_open`` to load only that one tensor — no full shard loading.\n\n    This keeps peak memory at O(single tensor) rather than O(full shard),\n    which matters for DeepSeek V3 where shards are ~4-5 GB each.\n    \"\"\"\n    # Build key → shard mapping from index\n    index_path = model_dir / \"model.safetensors.index.json\"\n    if index_path.exists():\n        with open(index_path) as f:\n            index_weight_map: dict[str, str] = json.load(f)[\"weight_map\"]\n    else:\n        # Single shard — build map from the one file\n        shard_files = sorted(model_dir.glob(\"*.safetensors\"))\n        index_weight_map = {}\n        for sf_path in shard_files:\n            with safe_open(str(sf_path), framework=\"pt\") as f:\n                for key in f.keys():  # noqa: SIM118\n                    index_weight_map[key] = sf_path.name\n\n    def load_tensor(name: str) -> torch.Tensor:\n        if name not in index_weight_map:\n            raise KeyError(f\"Tensor {name!r} not found in any shard at {model_dir}\")\n        shard_name = index_weight_map[name]\n        with safe_open(str(model_dir / shard_name), framework=\"pt\") as f:\n            return f.get_tensor(name)\n\n    return load_tensor\n\n\n# ---------------------------------------------------------------------------\n# vLLM compressed-tensors config\n# ---------------------------------------------------------------------------\n\n\ndef _weight_scale_key(weight_key: str) -> str:\n    \"\"\"Map a weight key to its compressed-tensors scale key.\n\n    Uses ``.weight_scale`` (compressed-tensors convention), NOT\n    ``.weight_scale_inv`` (DeepSeek native convention).\n    \"\"\"\n    return weight_key.removesuffix(\".weight\") + \".weight_scale\"\n\n\n# Linear projection suffixes used to build the compressed-tensors ignore list.\n# Only modules matching these suffixes are considered for the ignore list.\n_LINEAR_PROJ_SUFFIXES = (\n    \".q_proj.weight\",\n    \".q_a_proj.weight\",\n    \".q_b_proj.weight\",\n    \".kv_a_proj_with_mqa.weight\",\n    \".kv_b_proj.weight\",\n    \".o_proj.weight\",\n    \".gate_proj.weight\",\n    \".up_proj.weight\",\n    \".down_proj.weight\",\n)\n\n\ndef _build_vllm_quantization_config(output_weight_map: dict[str, str]) -> dict:\n    \"\"\"Build compressed-tensors quantization config for vLLM.\n\n    Produces a config dict that tells vLLM which layers are FP8-quantized\n    (routed experts) and which to ignore (everything else). 
No library\n    imports needed — the schema is fixed and well-known.\n\n    Args:\n        output_weight_map: Mapping of weight key -> shard filename.\n\n    Returns:\n        Dict suitable for config.json's ``compression_config`` field.\n    \"\"\"\n    # Determine which modules have been quantized (have .weight_scale)\n    quantized_prefixes = {\n        key.removesuffix(\".weight_scale\")\n        for key in output_weight_map\n        if key.endswith(\".weight_scale\")\n    }\n\n    # Build ignore list: linear projection modules that were NOT quantized\n    ignore: list[str] = []\n    for key in sorted(output_weight_map):\n        if not any(key.endswith(suffix) for suffix in _LINEAR_PROJ_SUFFIXES):\n            continue\n        prefix = key.removesuffix(\".weight\")\n        if prefix not in quantized_prefixes:\n            ignore.append(prefix)\n\n    # Also ignore lm_head if present and not quantized\n    if \"lm_head.weight\" in output_weight_map and \"lm_head\" not in quantized_prefixes:\n        ignore.append(\"lm_head\")\n\n    return {\n        \"quant_method\": \"compressed-tensors\",\n        \"format\": \"float-quantized\",\n        \"quantization_status\": \"compressed\",\n        \"global_compression_ratio\": None,\n        \"config_groups\": {\n            \"group_0\": {\n                \"targets\": [\"Linear\"],\n                \"weights\": {\n                    \"num_bits\": 8,\n                    \"type\": \"float\",\n                    \"symmetric\": True,\n                    \"strategy\": \"block\",\n                    \"block_structure\": [_FP8_BLOCK_SIZE, _FP8_BLOCK_SIZE],\n                    \"dynamic\": False,\n                },\n                \"input_activations\": {\n                    \"num_bits\": 8,\n                    \"type\": \"float\",\n                    \"symmetric\": True,\n                    \"strategy\": \"tensor\",\n                    \"dynamic\": True,\n                },\n            },\n        },\n        \"ignore\": ignore,\n    }\n\n\n_VLLM_COMPAT_QUANT_CONFIG_FIELDS = {\n    \"config_groups\",\n    \"format\",\n    \"global_compression_ratio\",\n    \"ignore\",\n    \"kv_cache_scheme\",\n    \"quantization_status\",\n}\n_VLLM_COMPAT_QUANT_SCHEME_FIELDS = {\n    \"format\",\n    \"input_activations\",\n    \"output_activations\",\n    \"targets\",\n    \"weights\",\n}\n_VLLM_COMPAT_QUANT_ARGS_FIELDS = {\n    \"actorder\",\n    \"block_structure\",\n    \"dynamic\",\n    \"group_size\",\n    \"num_bits\",\n    \"observer\",\n    \"observer_kwargs\",\n    \"strategy\",\n    \"symmetric\",\n    \"type\",\n}\n\n\ndef _serialize_for_vllm(config: dict) -> dict:\n    \"\"\"Serialize only the compressed-tensors fields the current vLLM path needs.\n\n    Uses an allowlist so new compressed-tensors fields are omitted automatically\n    instead of breaking older vLLM builds.\n    \"\"\"\n    serialized: dict = {}\n    for key, value in config.items():\n        if key == \"config_groups\" and isinstance(value, dict):\n            serialized[key] = {\n                group_name: _serialize_vllm_scheme(group)\n                for group_name, group in value.items()\n                if isinstance(group, dict)\n            }\n            continue\n        if key in _VLLM_COMPAT_QUANT_CONFIG_FIELDS:\n            serialized[key] = value\n    serialized[\"quant_method\"] = \"compressed-tensors\"\n    return serialized\n\n\ndef _serialize_vllm_scheme(group: dict) -> dict:\n    \"\"\"Serialize a single quantization scheme for vLLM 
compatibility.\"\"\"\n    serialized: dict = {}\n    for key, value in group.items():\n        if key in {\"weights\", \"input_activations\", \"output_activations\"} and isinstance(\n            value, dict\n        ):\n            serialized[key] = {\n                field: field_value\n                for field, field_value in value.items()\n                if field in _VLLM_COMPAT_QUANT_ARGS_FIELDS\n            }\n            continue\n        if key in _VLLM_COMPAT_QUANT_SCHEME_FIELDS:\n            serialized[key] = value\n    return serialized\n\n\n# ---------------------------------------------------------------------------\n# Resume state management\n# ---------------------------------------------------------------------------\n\n_MERGE_STATE_FILE = \"merge_state.json\"\n\n\ndef _load_resume_state(output_path: Path) -> dict:\n    \"\"\"Load resume state from a previous incomplete run.\n\n    Returns:\n        Dict with keys: ``status``, ``completed_shards`` (list of filenames),\n        ``total_shards``. Returns empty dict if no state file exists.\n    \"\"\"\n    state_file = output_path / _MERGE_STATE_FILE\n    if not state_file.exists():\n        return {}\n    with open(state_file) as f:\n        state = json.load(f)\n\n    # Validate: every completed shard file must exist\n    completed = state.get(\"completed_shards\", [])\n    for shard_name in completed:\n        if not (output_path / shard_name).exists():\n            raise WeightsMergeError(\n                f\"Resume state references {shard_name!r} but file not found in {output_path}. \"\n                f\"Delete {output_path} and restart.\"\n            )\n    return state\n\n\ndef _save_merge_state(\n    output_path: Path,\n    *,\n    status: str,\n    completed_shards: list[str],\n    total_shards: int,\n) -> None:\n    \"\"\"Save merge state atomically for resume support.\"\"\"\n    state = {\n        \"status\": status,\n        \"completed_shards\": completed_shards,\n        \"total_shards\": total_shards,\n    }\n    tmp = output_path / f\"{_MERGE_STATE_FILE}.tmp\"\n    with open(tmp, \"w\") as f:\n        json.dump(state, f, indent=2)\n    tmp.rename(output_path / _MERGE_STATE_FILE)\n\n\ndef _save_shard_atomic(\n    output_path: Path, shard_name: str, tensors: dict[str, torch.Tensor]\n) -> None:\n    \"\"\"Save a shard file atomically (write to temp, then rename).\"\"\"\n    tmp_name = f\"{shard_name}.tmp\"\n    save_file(tensors, str(output_path / tmp_name))\n    (output_path / tmp_name).rename(output_path / shard_name)\n\n\n# ---------------------------------------------------------------------------\n# Main entry point\n# ---------------------------------------------------------------------------\n\n\ndef build_quantized(\n    *,\n    base_model: str,\n    adapter_path: str,\n    output_path: str,\n    trust_remote_code: bool,\n    model_dir: Path,\n    config_dict: dict,\n    serving_format: str,\n) -> None:\n    \"\"\"Merge LoRA adapter and quantize routed experts to FP8.\n\n    Processes one safetensors shard at a time:\n    1. Load shard tensors\n    2. Apply any LoRA merge ops targeting this shard\n    3. Quantize routed expert weights to FP8 with blockwise scales\n    4. Preserve dense/shared-expert weights in BF16\n    5. Write output shard (preserving input shard layout)\n    6. 
Track progress for resume support\n\n    After all shards are processed:\n    - Write safetensors index\n    - Patch config.json with compressed-tensors metadata\n    - Copy tokenizer, model code\n\n    Args:\n        base_model: Model name or path (for tokenizer loading).\n        adapter_path: Path to adapter directory.\n        output_path: Where to write the quantized model.\n        trust_remote_code: Whether to trust remote code.\n        model_dir: Resolved local model directory.\n        config_dict: Parsed config.json dict.\n        serving_format: Serving framework format (e.g. \"vllm\").\n    \"\"\"\n    out = Path(output_path)\n\n    # Check for resume\n    resume_state = {}\n    if out.exists():\n        resume_state = _load_resume_state(out)\n        if not resume_state:\n            raise FileExistsError(f\"Output path already exists: {out}\")\n        if resume_state.get(\"status\") == \"completed\":\n            logger.info(\"Output already complete at %s, skipping\", out)\n            return\n        logger.info(\n            \"Resuming: %d/%d shards completed\",\n            len(resume_state.get(\"completed_shards\", [])),\n            resume_state.get(\"total_shards\", \"?\"),\n        )\n    else:\n        out.mkdir(parents=True, exist_ok=False)\n\n    # 1. Load adapter\n    adapter_weights, adapter_config = load_adapter_weights(Path(adapter_path))\n\n    # 2. Read model metadata\n    model_shapes = get_model_state_shapes(model_dir)\n    model_state_keys = set(model_shapes.keys())\n\n    # Pre-filter keys that DeepSeek checkpoints include but shouldn't be merged.\n    # Also exclude .weight_scale_inv — these are native FP8 scales, not merge targets.\n    filtered_keys = {\n        k\n        for k in model_state_keys\n        if not _should_skip_checkpoint_key(k) and not k.endswith(\".weight_scale_inv\")\n    }\n\n    # 3. Detect merge profile and plan ops\n    profile = detect_merge_profile(config_dict, model_state_keys)\n    logger.info(\n        \"Detected merge profile: expert_layout=%s, language_model_prefix=%s\",\n        profile.expert_layout,\n        profile.has_language_model_prefix,\n    )\n\n    merge_ops = plan_merge_ops(adapter_weights, adapter_config, filtered_keys, profile)\n    total_ops = sum(len(ops) for ops in merge_ops.values())\n    logger.info(\"Planned %d merge operations across %d target keys\", total_ops, len(merge_ops))\n\n    # Validate shapes against filtered keys\n    filtered_shapes = {k: v for k, v in model_shapes.items() if k in filtered_keys}\n    validate_merge_op_shapes(merge_ops, filtered_shapes)\n\n    # 4. Set up native FP8 handling (cross-shard scale lookup)\n    is_native_fp8 = _has_native_fp8_quantization(config_dict)\n    native_block_size = _get_native_block_size(config_dict) if is_native_fp8 else None\n    cross_shard_loader = _make_cross_shard_tensor_loader(model_dir) if is_native_fp8 else None\n    if is_native_fp8:\n        logger.info(\n            \"Native FP8 checkpoint detected (block_size=%s), will dequantize before re-quantize\",\n            native_block_size,\n        )\n\n    # 5. 
Process shards\n    shard_files = get_shard_files(model_dir)\n    completed_shards = set(resume_state.get(\"completed_shards\", []))\n    all_completed: list[str] = list(completed_shards)\n    weight_map: dict[str, str] = {}\n\n    # Rebuild weight map from already-completed shards\n    for shard_name in completed_shards:\n        shard_tensors = load_file(str(out / shard_name))\n        for key in shard_tensors:\n            weight_map[key] = shard_name\n        # Pop merge ops for completed shard keys\n        for key in shard_tensors:\n            merge_ops.pop(key, None)\n        del shard_tensors\n\n    logger.info(\n        \"Processing %d input shard(s) (%d already completed)\",\n        len(shard_files),\n        len(completed_shards),\n    )\n    ops_applied = 0\n\n    for i, shard_file in enumerate(shard_files):\n        # Determine output shard name (preserve input naming)\n        out_shard_name = shard_file\n\n        if out_shard_name in completed_shards:\n            logger.info(\"Skipping completed shard %d/%d: %s\", i + 1, len(shard_files), shard_file)\n            continue\n\n        logger.info(\"Processing shard %d/%d: %s\", i + 1, len(shard_files), shard_file)\n        tensors = load_file(str(model_dir / shard_file))\n        output_tensors: dict[str, torch.Tensor] = {}\n\n        for key in list(tensors.keys()):\n            tensor = tensors[key]\n\n            # Skip keys that shouldn't be in output\n            if _should_skip_checkpoint_key(key):\n                continue\n\n            # Skip native scale_inv tensors (we generate new .weight_scale)\n            if key.endswith(\".weight_scale_inv\"):\n                continue\n\n            # Step 1: Dequantize native FP8 weights BEFORE merge\n            # Native FP8 checkpoints store weights in FP8 + scale_inv.\n            # We must dequantize to BF16 first so the LoRA merge math works\n            # correctly in float precision.\n            if key.endswith(\".weight\") and tensor.dtype == torch.float8_e4m3fn and is_native_fp8:\n                scale_key = key.replace(\".weight\", \".weight_scale_inv\")\n                # Scale may be in this shard or a different one\n                scale_inv = tensors.get(scale_key)\n                if scale_inv is None and cross_shard_loader is not None:\n                    scale_inv = cross_shard_loader(scale_key)\n                if scale_inv is not None:\n                    assert native_block_size is not None\n                    tensor = dequantize_blockwise(tensor, scale_inv, block_size=native_block_size)\n                else:\n                    raise WeightsMergeError(\n                        f\"Native FP8 weight {key!r} has no .weight_scale_inv tensor \"\n                        f\"in any shard. 
Cannot dequantize for merge.\"\n                    )\n\n            # Step 2: Apply LoRA merge ops (on dequantized BF16 tensors)\n            ops_for_key = merge_ops.pop(key, [])\n            if ops_for_key:\n                temp = {key: tensor}\n                for op in ops_for_key:\n                    apply_merge_op(temp, op)\n                    ops_applied += 1\n                tensor = temp[key]\n\n            # Step 3: Quantize routed experts to FP8, preserve everything else\n            if _is_routed_expert_weight(key) and key.endswith(\".weight\"):\n                fp8_tensor, scale = quantize_blockwise(tensor)\n                output_tensors[key] = fp8_tensor\n                output_tensors[_weight_scale_key(key)] = scale\n            else:\n                output_tensors[key] = tensor\n\n            weight_map[key] = out_shard_name\n            # Also track scale tensors in weight map\n            scale_out_key = _weight_scale_key(key) if key.endswith(\".weight\") else None\n            if (\n                scale_out_key\n                and scale_out_key in output_tensors\n                and scale_out_key not in weight_map\n            ):\n                weight_map[scale_out_key] = out_shard_name\n\n        del tensors\n\n        # Save shard atomically\n        _save_shard_atomic(out, out_shard_name, output_tensors)\n        del output_tensors\n\n        all_completed.append(out_shard_name)\n        _save_merge_state(\n            out,\n            status=\"in_progress\",\n            completed_shards=all_completed,\n            total_shards=len(shard_files),\n        )\n\n    # Verify all merge ops were consumed\n    if merge_ops:\n        unconsumed = list(merge_ops.keys())\n        raise WeightsMergeError(\n            f\"Merge ops not applied — {len(unconsumed)} target keys not found in any shard: \"\n            f\"{unconsumed[:5]}{'...' if len(unconsumed) > 5 else ''}\"\n        )\n\n    logger.info(\"Applied %d/%d merge operations\", ops_applied, total_ops)\n\n    # 6. Write index\n    shard_names = set(weight_map.values())\n    index = {\n        \"metadata\": {\"total_size\": _compute_total_size(out, shard_names)},\n        \"weight_map\": dict(sorted(weight_map.items())),\n    }\n    index_path = out / \"model.safetensors.index.json\"\n    with open(index_path, \"w\") as f:\n        json.dump(index, f, indent=2)\n\n    # 7. Copy config and patch with quantization metadata\n    src_config = model_dir / \"config.json\"\n    if src_config.exists():\n        copy_artifact_file(src_config, out / \"config.json\")\n\n    if serving_format == \"vllm\":\n        quant_config = _build_vllm_quantization_config(weight_map)\n        _patch_config_with_quantization(out, quant_config)\n\n    # 8. Copy model code and tokenizer\n    copy_model_code_files(model_dir, out)\n    save_tokenizer_and_processor(\n        base_model, out, is_multimodal_from_dict(config_dict), trust_remote_code\n    )\n\n    # 9. 
Mark complete\n    _save_merge_state(\n        out,\n        status=\"completed\",\n        completed_shards=all_completed,\n        total_shards=len(shard_files),\n    )\n\n    logger.info(\"Done — quantized model saved to %s\", out)\n\n\n_DTYPE_SIZES: dict[str, int] = {\n    \"F64\": 8,\n    \"F32\": 4,\n    \"F16\": 2,\n    \"BF16\": 2,\n    \"I64\": 8,\n    \"I32\": 4,\n    \"I16\": 2,\n    \"I8\": 1,\n    \"U8\": 1,\n    \"F8_E4M3\": 1,\n    \"F8_E5M2\": 1,\n    \"BOOL\": 1,\n}\n\n\ndef _compute_total_size(output_path: Path, shard_names: set[str]) -> int:\n    \"\"\"Compute total byte size of all tensors across output shards.\n\n    Reads safetensors headers only (shape + dtype) without loading tensor data,\n    matching the HuggingFace convention for ``model.safetensors.index.json``.\n    \"\"\"\n    total = 0\n    for name in shard_names:\n        shard_path = output_path / name\n        if not shard_path.exists():\n            continue\n        with safe_open(str(shard_path), framework=\"pt\") as f:\n            for key in f.keys():  # noqa: SIM118\n                shape = f.get_slice(key).get_shape()\n                dtype_str = f.get_slice(key).get_dtype()\n                numel = 1\n                for dim in shape:\n                    numel *= dim\n                total += numel * _DTYPE_SIZES.get(dtype_str, 4)\n    return total\n\n\ndef _patch_config_with_quantization(output_path: Path, quant_config: dict) -> None:\n    \"\"\"Patch config.json with compressed-tensors quantization metadata.\n\n    Adds ``compression_config`` and removes ``quantization_config`` (which\n    refers to the input model's native quantization, not our output).\n    \"\"\"\n    config_path = output_path / \"config.json\"\n    with open(config_path) as f:\n        config = json.load(f)\n\n    config[\"compression_config\"] = _serialize_for_vllm(quant_config)\n    config.pop(\"quantization_config\", None)\n\n    with open(config_path, \"w\") as f:\n        json.dump(config, f, indent=2)\n"
  },
  {
    "path": "tinker_cookbook/weights/_export/_shard.py",
    "content": "\"\"\"Shard-by-shard export strategy.\n\nProcesses one safetensors shard at a time, keeping peak memory proportional to\nthe largest shard rather than the full model. Produces output identical to the\nfull-model path.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nfrom pathlib import Path\n\nfrom safetensors.torch import load_file\n\nfrom tinker_cookbook.exceptions import WeightsMergeError\nfrom tinker_cookbook.weights._artifacts import (\n    ShardWriter,\n    copy_artifact_file,\n    copy_model_code_files,\n    get_model_state_shapes,\n    get_shard_files,\n    load_adapter_weights,\n)\nfrom tinker_cookbook.weights._export import (\n    cleanup_on_failure,\n    is_multimodal_from_dict,\n    save_tokenizer_and_processor,\n)\nfrom tinker_cookbook.weights._merge import (\n    apply_merge_op,\n    detect_merge_profile,\n    plan_merge_ops,\n    validate_merge_op_shapes,\n)\n\nlogger = logging.getLogger(__name__)\n\n\ndef build_sharded(\n    *,\n    base_model: str,\n    adapter_path: str,\n    output_path: str,\n    trust_remote_code: bool,\n    model_dir: Path,\n    config_dict: dict,\n) -> None:\n    \"\"\"Merge by processing one safetensors shard at a time.\n\n    Args:\n        base_model: Original model name (used for tokenizer loading).\n        adapter_path: Path to adapter directory.\n        output_path: Where to write the merged model.\n        trust_remote_code: Whether to trust remote code for HF loading.\n        model_dir: Resolved local directory containing model files.\n        config_dict: Parsed config.json dict (loaded by dispatcher).\n    \"\"\"\n    # 0. Fail fast if output already exists (before any expensive work)\n    out = Path(output_path)\n    if out.exists():\n        raise FileExistsError(f\"Output path already exists: {out}\")\n\n    # 1. Load adapter (small — only LoRA matrices)\n    adapter_weights, adapter_config = load_adapter_weights(Path(adapter_path))\n\n    # 2. Read model state shapes from safetensors headers (no weight loading)\n    model_shapes = get_model_state_shapes(model_dir)\n    model_state_keys = set(model_shapes.keys())\n\n    # 3. Detect model-specific merge profile from config + key names\n    profile = detect_merge_profile(config_dict, model_state_keys)\n    logger.info(\n        \"Detected merge profile: expert_layout=%s, language_model_prefix=%s\",\n        profile.expert_layout,\n        profile.has_language_model_prefix,\n    )\n\n    # 4. Plan all merge ops (validates keys before any heavy I/O)\n    merge_ops = plan_merge_ops(adapter_weights, adapter_config, model_state_keys, profile)\n    total_ops = sum(len(ops) for ops in merge_ops.values())\n    logger.info(\"Planned %d merge operations across %d target keys\", total_ops, len(merge_ops))\n\n    # 5. Validate shapes upfront (catches mismatches before loading any shards)\n    validate_merge_op_shapes(merge_ops, model_shapes)\n\n    # 6. 
Process shards\n    out.mkdir(parents=True, exist_ok=False)\n\n    try:\n        shard_files = get_shard_files(model_dir)\n        logger.info(\"Processing %d input shard(s)\", len(shard_files))\n\n        writer = ShardWriter(out)\n        ops_applied = 0\n\n        for i, shard_file in enumerate(shard_files):\n            logger.info(\"Processing shard %d/%d: %s\", i + 1, len(shard_files), shard_file)\n            tensors = load_file(str(model_dir / shard_file))\n\n            # Apply any merge ops targeting keys in this shard\n            for key in list(tensors.keys()):\n                ops_for_key = merge_ops.pop(key, [])\n                for op in ops_for_key:\n                    apply_merge_op(tensors, op)\n                    ops_applied += 1\n\n            # Write all tensors from this shard to output\n            for key, tensor in tensors.items():\n                writer.add_tensor(key, tensor)\n            del tensors\n            writer.flush()\n\n        # 7. Verify all ops were consumed\n        if merge_ops:\n            unconsumed = list(merge_ops.keys())\n            raise WeightsMergeError(\n                f\"Merge ops not applied — {len(unconsumed)} target keys not found in any shard: \"\n                f\"{unconsumed[:5]}{'...' if len(unconsumed) > 5 else ''}\"\n            )\n\n        logger.info(\"Applied %d/%d merge operations\", ops_applied, total_ops)\n\n        # 8. Finalize output shards\n        weight_map = writer.finalize()\n\n        # 9. Write index file (only for multi-shard output; HF convention\n        #    is no index for single-shard models)\n        shard_names = set(weight_map.values())\n        if len(shard_names) > 1:\n            index = {\n                \"metadata\": {\"total_size\": writer.total_size},\n                \"weight_map\": dict(sorted(weight_map.items())),\n            }\n            index_path = out / \"model.safetensors.index.json\"\n            with open(index_path, \"w\") as f:\n                json.dump(index, f, indent=2)\n\n        # 10. Save config, tokenizer, and model code files.\n        #     Copy config.json directly (safe — it's a single known file).\n        #     Copy *.py files for trust_remote_code model support.\n        #     We intentionally don't glob-copy all non-weight files to avoid\n        #     accidentally including stale index files or other artifacts that\n        #     could break downstream loaders like vLLM/SGLang.\n        src_config = model_dir / \"config.json\"\n        if src_config.exists():\n            copy_artifact_file(src_config, out / \"config.json\")\n        copy_model_code_files(model_dir, out)\n        save_tokenizer_and_processor(\n            base_model, out, is_multimodal_from_dict(config_dict), trust_remote_code\n        )\n\n        logger.info(\"Done — merged model saved to %s\", out)\n    except Exception:\n        cleanup_on_failure(out)\n        raise\n"
  },
  {
    "path": "tinker_cookbook/weights/_merge.py",
    "content": "\"\"\"LoRA adapter merge logic.\n\nProvides shared merge primitives used by all export strategies:\n\n- ``MergeProfile`` / ``detect_merge_profile``: model-specific merge configuration\n- ``MergeOp`` / ``plan_merge_ops`` / ``apply_merge_op``: plan-then-execute merge pipeline\n- ``merge_lora_matrices`` / ``expand_expert_lora_tensors``: low-level math utilities\n- ``merge_adapter_weights``: backward-compatible convenience wrapper\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING\n\nimport torch\n\nfrom tinker_cookbook.exceptions import WeightsMergeError\n\nif TYPE_CHECKING:\n    from collections.abc import Callable\n\n    # Profile detector: (config_dict, model_state_keys) -> MergeProfile | None\n    _ProfileDetector = Callable[[dict, \"set[str]\"], \"MergeProfile | None\"]\n\n# ---------------------------------------------------------------------------\n# MergeProfile — model-specific merge configuration\n# ---------------------------------------------------------------------------\n\n_VALID_EXPERT_LAYOUTS = frozenset({\"separate\", \"fused_interleaved\", \"fused_concatenated\"})\n\n\n@dataclass(frozen=True)\nclass MergeProfile:\n    \"\"\"Describes model-specific merge behavior.\n\n    Captures merge-level variation between model families: how adapter weight\n    names map to model weight names, and how expert weights are laid out.\n\n    Does NOT capture export-level concerns (output format, quantization,\n    shard layout) — those belong in export strategy modules.\n    \"\"\"\n\n    expert_layout: str = \"separate\"\n    \"\"\"How expert weights are arranged in the model.\n\n    - ``\"separate\"`` — individual weight per expert (Qwen3 MoE, DeepSeek)\n    - ``\"fused_interleaved\"`` — gate_up_proj with [g0, u0, g1, u1, ...] (GPT-OSS)\n    - ``\"fused_concatenated\"`` — gate_up_proj with [gate | up] (Qwen3.5, Qwen3-VL)\n    \"\"\"\n\n    extra_key_remaps: tuple[tuple[str, str], ...] = ()\n    \"\"\"Additional key remapping rules applied after standard remaps.\n\n    Each ``(old, new)`` pair is applied via ``str.replace`` on the target key.\n    Example: ``((\".attn\", \".self_attn\"),)`` for GPT-OSS.\n\n    Note: these remaps are applied to non-expert keys only. Expert keys go\n    through a separate remapping path (``w1→gate_proj``, etc.) that doesn't\n    use ``extra_key_remaps``. If a future model needs remaps on expert keys,\n    this should be extended.\n\n    Uses tuple-of-tuples rather than dict so ``MergeProfile`` stays hashable.\n    \"\"\"\n\n    has_language_model_prefix: bool = False\n    \"\"\"Whether model keys use ``model.language_model.`` prefix (vision models).\"\"\"\n\n\ndef detect_merge_profile(\n    model_config: dict,\n    model_state_keys: set[str],\n) -> MergeProfile:\n    \"\"\"Detect merge profile from model config and weight key names.\n\n    Tries each registered model-specific detector in order. The first one\n    that returns a profile wins. Falls back to :func:`_detect_default_profile`\n    if none match.\n\n    To add support for a new model family, write a detector function with\n    signature ``(dict, set[str]) -> MergeProfile | None`` and append it to\n    :data:`_PROFILE_DETECTORS`.\n\n    Works with both loaded models (full path) and safetensors headers\n    (shard path), since both can provide a config dict and key names.\n\n    Args:\n        model_config: Parsed ``config.json`` or equivalent dict. 
Uses the\n            ``\"architectures\"`` key for model family detection.\n        model_state_keys: Weight key names from the model state dict\n            or safetensors headers.\n    \"\"\"\n    for detector in _PROFILE_DETECTORS:\n        profile = detector(model_config, model_state_keys)\n        if profile is not None:\n            return profile\n    return _detect_default_profile(model_config, model_state_keys)\n\n\n# ---------------------------------------------------------------------------\n# Per-model profile detectors\n#\n# Each detector returns a MergeProfile if it recognizes the model, or None\n# to pass to the next detector. Add new detectors to _PROFILE_DETECTORS.\n# ---------------------------------------------------------------------------\n\n\ndef _detect_gpt_oss_profile(model_config: dict, model_state_keys: set[str]) -> MergeProfile | None:\n    \"\"\"Detect GPT-OSS models.\n\n    GPT-OSS uses ``.attn`` instead of ``.self_attn`` for attention layers, and\n    an interleaved ``[g0, u0, g1, u1, ...]`` layout for fused gate/up expert\n    projections.\n    \"\"\"\n    architectures = model_config.get(\"architectures\", [])\n    if not any(\"GptOss\" in a for a in architectures):\n        return None\n\n    has_fused = any(k.endswith(\".experts.gate_up_proj\") for k in model_state_keys)\n    has_lm_prefix = any(k.startswith(\"model.language_model.\") for k in model_state_keys)\n\n    return MergeProfile(\n        expert_layout=\"fused_interleaved\" if has_fused else \"separate\",\n        extra_key_remaps=((\".attn\", \".self_attn\"),),\n        has_language_model_prefix=has_lm_prefix,\n    )\n\n\ndef _detect_default_profile(model_config: dict, model_state_keys: set[str]) -> MergeProfile:\n    \"\"\"Default profile for models without special merge requirements.\n\n    Handles Qwen, DeepSeek, and other standard model families. Detects fused\n    expert layout (concatenated, not interleaved) and vision model prefix\n    from key names alone.\n    \"\"\"\n    has_fused = any(k.endswith(\".experts.gate_up_proj\") for k in model_state_keys)\n    has_lm_prefix = any(k.startswith(\"model.language_model.\") for k in model_state_keys)\n\n    return MergeProfile(\n        expert_layout=\"fused_concatenated\" if has_fused else \"separate\",\n        has_language_model_prefix=has_lm_prefix,\n    )\n\n\ndef _detect_deepseek_profile(model_config: dict, model_state_keys: set[str]) -> MergeProfile | None:\n    \"\"\"Detect DeepSeek V3/V3.1 models.\n\n    DeepSeek uses separate per-expert weights (not fused) and standard key\n    naming. Detection is based on ``model_type`` rather than architecture\n    strings for reliability across versions.\n    \"\"\"\n    if model_config.get(\"model_type\") not in (\"deepseek_v3\",):\n        return None\n\n    has_lm_prefix = any(k.startswith(\"model.language_model.\") for k in model_state_keys)\n\n    return MergeProfile(\n        expert_layout=\"separate\",\n        has_language_model_prefix=has_lm_prefix,\n    )\n\n\n# Detectors are tried in order. First match wins.\n_PROFILE_DETECTORS: list = [\n    _detect_gpt_oss_profile,\n    _detect_deepseek_profile,\n]\n\n\n# ---------------------------------------------------------------------------\n# MergeOp — a pending LoRA merge operation\n# ---------------------------------------------------------------------------\n\n\n@dataclass\nclass MergeOp:\n    \"\"\"A pending LoRA merge operation.\n\n    Stores only the small rank-sized LoRA matrices. 
The model-sized delta is\n    computed on-the-fly during :func:`apply_merge_op`, keeping peak memory\n    proportional to LoRA rank rather than model size.\n    \"\"\"\n\n    target_key: str\n\n    lora_A: torch.Tensor\n    \"\"\"Shape ``(rank, in_dim)`` for 2D ops, ``(num_experts, rank, in_dim)`` for 3D.\"\"\"\n\n    lora_B: torch.Tensor\n    \"\"\"Shape ``(out_dim, rank)`` for 2D ops, ``(num_experts, out_dim, rank)`` for 3D.\n    Pre-scaled by ``lora_alpha / r``.\"\"\"\n\n    is_expert_3d: bool = False\n    \"\"\"True for fused expert weights where lora_A/B are 3D.\"\"\"\n\n    fused_proj_idx: int | None = None\n    \"\"\"For fused gate/up projections: 0 = gate, 1 = up, None = not fused.\"\"\"\n\n    fused_proj_interleaved: bool = False\n    \"\"\"GPT-OSS stores fused gate/up projections interleaved rather than concatenated.\"\"\"\n\n\n# ---------------------------------------------------------------------------\n# Low-level math utilities\n# ---------------------------------------------------------------------------\n\n\ndef merge_lora_matrices(lora_A: torch.Tensor, lora_B: torch.Tensor) -> torch.Tensor:\n    \"\"\"Compute 2D LoRA delta: ``lora_B @ lora_A``.\n\n    Args:\n        lora_A: Shape ``(rank, in_dim)``.\n        lora_B: Shape ``(out_dim, rank)``, pre-scaled by ``alpha / r``.\n\n    Returns:\n        Delta tensor of shape ``(out_dim, in_dim)``.\n    \"\"\"\n    return lora_B @ lora_A\n\n\ndef expand_expert_lora_tensors(\n    lora_A: torch.Tensor, lora_B: torch.Tensor\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Broadcast shared expert LoRA tensors to match num_experts.\n\n    When one tensor has ``shape[0] == 1`` and the other has ``shape[0] > 1``,\n    expands the single-expert tensor to match.\n\n    Args:\n        lora_A: Shape ``(num_experts_a, rank, in_dim)``.\n        lora_B: Shape ``(num_experts_b, out_dim, rank)``.\n\n    Returns:\n        Tuple of ``(lora_A, lora_B)`` with matching ``shape[0]``.\n\n    Raises:\n        ValueError: If both tensors have ``shape[0] == 1``.\n    \"\"\"\n    if lora_A.shape[0] == 1 and lora_B.shape[0] == 1:\n        raise WeightsMergeError(\n            f\"Cannot broadcast expert LoRA: both A and B have 1 expert \"\n            f\"(lora_A: {lora_A.shape}, lora_B: {lora_B.shape})\"\n        )\n    if lora_A.shape[0] == 1:\n        lora_A = lora_A.expand(lora_B.shape[0], -1, -1)\n    elif lora_B.shape[0] == 1:\n        lora_B = lora_B.expand(lora_A.shape[0], -1, -1)\n    elif lora_A.shape[0] != lora_B.shape[0]:\n        raise WeightsMergeError(\n            f\"Expert count mismatch: lora_A has {lora_A.shape[0]} experts, \"\n            f\"lora_B has {lora_B.shape[0]} experts \"\n            f\"(lora_A: {lora_A.shape}, lora_B: {lora_B.shape})\"\n        )\n    return lora_A, lora_B\n\n\ndef apply_merged_weight(target: torch.Tensor, merged_lora: torch.Tensor) -> None:\n    \"\"\"Add a merged LoRA delta to a model weight tensor in-place.\"\"\"\n    if target.shape != merged_lora.shape:\n        raise WeightsMergeError(\n            f\"Shape mismatch: target {target.shape} vs merged LoRA {merged_lora.shape}\"\n        )\n    new_data = target.float() + merged_lora.float().to(target.device)\n    target.copy_(new_data.to(target.dtype))\n\n\n# ---------------------------------------------------------------------------\n# Plan + apply\n# ---------------------------------------------------------------------------\n\n\ndef plan_merge_ops(\n    adapter_weights: dict[str, torch.Tensor],\n    adapter_config: dict,\n    model_state_keys: set[str],\n 
   profile: MergeProfile,\n) -> dict[str, list[MergeOp]]:\n    \"\"\"Plan all merge operations without executing them.\n\n    Maps adapter weight names to model weight keys using the profile's\n    remapping rules, validates all target keys exist, and returns a dict\n    of pending merge operations grouped by target key.\n\n    Args:\n        adapter_weights: LoRA weight tensors from the adapter.\n        adapter_config: Adapter config with ``lora_alpha`` and ``r`` keys.\n        model_state_keys: Set of weight key names in the base model.\n        profile: Model-specific merge configuration.\n\n    Returns:\n        Mapping from model weight key to list of :class:`MergeOp` targeting it.\n\n    Raises:\n        KeyError: If adapter config is missing required keys, or adapter\n            weights map to keys not found in the model.\n        ValueError: If expert LoRA tensors have unexpected shapes, or\n            ``profile.expert_layout`` is invalid.\n    \"\"\"\n    for key in (\"lora_alpha\", \"r\"):\n        if key not in adapter_config:\n            raise WeightsMergeError(f\"Adapter config missing required key: {key!r}\")\n\n    if profile.expert_layout not in _VALID_EXPERT_LAYOUTS:\n        raise WeightsMergeError(\n            f\"Invalid expert_layout {profile.expert_layout!r}. \"\n            f\"Must be one of: {sorted(_VALID_EXPERT_LAYOUTS)}\"\n        )\n\n    scaling = adapter_config[\"lora_alpha\"] / adapter_config[\"r\"]\n    adapter_weight_names = [n.replace(\".lora_A\", \"\") for n in adapter_weights if \".lora_A\" in n]\n\n    if not adapter_weight_names:\n        import logging\n\n        logging.getLogger(__name__).warning(\n            \"No LoRA weights found in adapter (no keys containing '.lora_A'). \"\n            \"The output model will be identical to the base model. 
\"\n            \"Check that the adapter path points to a valid Tinker LoRA adapter.\"\n        )\n\n    is_fused = profile.expert_layout in (\"fused_interleaved\", \"fused_concatenated\")\n    is_interleaved = profile.expert_layout == \"fused_interleaved\"\n\n    # Standard name remapping (order matters: strip prefix before vision remap)\n    name_remaps: list[tuple[str, str]] = [\n        (\"base_model.model.\", \"\"),\n        (\"model.unembed_tokens\", \"lm_head\"),\n    ]\n    if profile.has_language_model_prefix:\n        name_remaps.append((\"model.\", \"model.language_model.\"))\n\n    ops: dict[str, list[MergeOp]] = {}\n\n    for n in adapter_weight_names:\n        target_key = n\n        for old, new in name_remaps:\n            target_key = target_key.replace(old, new)\n\n        lora_A = adapter_weights[n.replace(\".weight\", \".lora_A.weight\")].float()\n        lora_B = adapter_weights[n.replace(\".weight\", \".lora_B.weight\")].float() * scaling\n\n        if \".experts\" not in n:\n            _plan_non_expert_op(target_key, lora_A, lora_B, n, profile, model_state_keys, ops)\n        else:\n            _plan_expert_ops(\n                target_key, lora_A, lora_B, n, model_state_keys, ops, is_fused, is_interleaved\n            )\n\n    return ops\n\n\ndef _plan_non_expert_op(\n    target_key: str,\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    adapter_name: str,\n    profile: MergeProfile,\n    model_state_keys: set[str],\n    ops: dict[str, list[MergeOp]],\n) -> None:\n    \"\"\"Plan a merge op for a standard (non-expert) linear layer.\"\"\"\n    for old, new in profile.extra_key_remaps:\n        target_key = target_key.replace(old, new)\n\n    if target_key not in model_state_keys:\n        raise WeightsMergeError(\n            f\"Adapter weight {adapter_name!r} mapped to {target_key!r} \"\n            f\"which does not exist in the model state dict\"\n        )\n    ops.setdefault(target_key, []).append(\n        MergeOp(target_key=target_key, lora_A=lora_A, lora_B=lora_B)\n    )\n\n\ndef _plan_expert_ops(\n    target_key: str,\n    lora_A: torch.Tensor,\n    lora_B: torch.Tensor,\n    adapter_name: str,\n    model_state_keys: set[str],\n    ops: dict[str, list[MergeOp]],\n    is_fused: bool,\n    is_interleaved: bool,\n) -> None:\n    \"\"\"Plan merge ops for expert weights (separate or fused).\"\"\"\n    if lora_A.ndim != 3 or lora_B.ndim != 3:\n        raise WeightsMergeError(\n            f\"Expert LoRA weights must be 3D, got lora_A: {lora_A.shape}, lora_B: {lora_B.shape}\"\n        )\n    lora_A, lora_B = expand_expert_lora_tensors(lora_A, lora_B)\n\n    # Expert weight name remapping\n    target_key = target_key.replace(\".w1.weight\", \".gate_proj.weight\")\n    target_key = target_key.replace(\".w3.weight\", \".up_proj.weight\")\n    target_key = target_key.replace(\".w2.weight\", \".down_proj.weight\")\n\n    if not is_fused:\n        # Separate per-expert weights: create one 2D MergeOp per expert\n        for exp_idx in range(lora_A.shape[0]):\n            target_key_exp = target_key.replace(\".experts\", f\".experts.{exp_idx}\")\n            if target_key_exp not in model_state_keys:\n                raise WeightsMergeError(\n                    f\"Adapter weight {adapter_name!r} mapped to {target_key_exp!r} \"\n                    f\"which does not exist in the model state dict\"\n                )\n            ops.setdefault(target_key_exp, []).append(\n                MergeOp(\n                    target_key=target_key_exp,\n                    
lora_A=lora_A[exp_idx],\n                    lora_B=lora_B[exp_idx],\n                )\n            )\n    else:\n        # Fused expert weights: create one 3D MergeOp\n        fused_proj_idx: int | None = None\n        if target_key.endswith(\".gate_proj.weight\"):\n            fused_proj_idx = 0\n            target_key = target_key.replace(\".gate_proj.weight\", \".gate_up_proj\")\n        elif target_key.endswith(\".up_proj.weight\"):\n            fused_proj_idx = 1\n            target_key = target_key.replace(\".up_proj.weight\", \".gate_up_proj\")\n        else:\n            target_key = target_key.replace(\".down_proj.weight\", \".down_proj\")\n\n        if target_key not in model_state_keys:\n            raise WeightsMergeError(\n                f\"Adapter weight {adapter_name!r} mapped to {target_key!r} \"\n                f\"which does not exist in the model state dict\"\n            )\n        ops.setdefault(target_key, []).append(\n            MergeOp(\n                target_key=target_key,\n                lora_A=lora_A,\n                lora_B=lora_B,\n                is_expert_3d=True,\n                fused_proj_idx=fused_proj_idx,\n                fused_proj_interleaved=is_interleaved,\n            )\n        )\n\n\ndef validate_merge_op_shapes(\n    ops: dict[str, list[MergeOp]],\n    model_shapes: dict[str, tuple[int, ...]],\n) -> None:\n    \"\"\"Validate all merge op output shapes against model weight shapes upfront.\n\n    Call this after :func:`plan_merge_ops` and before processing any shards.\n    Catches shape mismatches early, before expensive shard I/O begins.\n\n    Args:\n        ops: Mapping from target key to merge ops (from :func:`plan_merge_ops`).\n        model_shapes: Mapping from weight key to shape (from\n            :func:`~tinker_cookbook.weights._artifacts.get_model_state_shapes`).\n\n    Raises:\n        ValueError: If any merge op's delta shape doesn't match its target.\n    \"\"\"\n    for target_key, op_list in ops.items():\n        target_shape = model_shapes[target_key]\n        for op in op_list:\n            if op.is_expert_3d:\n                # bmm(A.T, B.T) → (num_experts, in_dim, out_dim)\n                n_exp, rank, in_dim = op.lora_A.shape\n                _, out_dim, _ = op.lora_B.shape\n                delta_shape = (n_exp, in_dim, out_dim)\n\n                if op.fused_proj_idx is not None:\n                    # Delta targets a slice of the fused tensor\n                    if op.fused_proj_interleaved:\n                        # Interleaved: target[:, :, idx::2] has shape (n, d, fused//2)\n                        expected = (target_shape[0], target_shape[1], target_shape[2] // 2)\n                    else:\n                        # Concatenated: target[:, :, start:start+half] has shape (n, d, fused//2)\n                        expected = (target_shape[0], target_shape[1], target_shape[2] // 2)\n                else:\n                    expected = target_shape\n\n                if delta_shape != expected:\n                    raise WeightsMergeError(\n                        f\"Shape mismatch for {target_key!r}: \"\n                        f\"merge op produces {delta_shape} but target \"\n                        f\"{'slice ' if op.fused_proj_idx is not None else ''}\"\n                        f\"expects {expected}\"\n                    )\n            else:\n                # 2D: delta = lora_B @ lora_A → (out_dim, in_dim)\n                delta_shape = (op.lora_B.shape[0], op.lora_A.shape[1])\n                if delta_shape 
!= target_shape:\n                    raise WeightsMergeError(\n                        f\"Shape mismatch for {target_key!r}: \"\n                        f\"merge op produces {delta_shape} but target expects {target_shape}\"\n                    )\n\n\ndef apply_merge_op(tensors: dict[str, torch.Tensor], op: MergeOp) -> None:\n    \"\"\"Apply a single merge operation to a dict of tensors.\n\n    Computes the LoRA delta on-the-fly and merges into the target tensor.\n    Works with full model state dicts or individual shard tensor dicts.\n\n    Args:\n        tensors: Mutable dict of tensors (e.g. from :func:`safetensors.torch.load_file`\n            or :meth:`torch.nn.Module.state_dict`). Modified in-place.\n        op: The merge operation to apply.\n\n    Raises:\n        ValueError: If tensor shapes are incompatible.\n    \"\"\"\n    target = tensors[op.target_key]\n\n    if op.is_expert_3d:\n        # (num_experts, rank, in_dim), (num_experts, out_dim, rank)\n        # → (num_experts, in_dim, out_dim) via bmm of transposed\n        delta = torch.bmm(op.lora_A.transpose(-1, -2), op.lora_B.transpose(-1, -2))\n\n        if op.fused_proj_idx is not None:\n            if op.fused_proj_interleaved:\n                target_view = target[:, :, op.fused_proj_idx :: 2]\n            else:\n                proj_width = target.shape[-1] // 2\n                start = op.fused_proj_idx * proj_width\n                target_view = target[:, :, start : start + proj_width]\n            apply_merged_weight(target_view, delta)\n        else:\n            apply_merged_weight(target, delta)\n    else:\n        # 2D: standard linear or per-expert (already sliced during planning)\n        delta = merge_lora_matrices(op.lora_A, op.lora_B)\n        apply_merged_weight(tensors[op.target_key], delta)\n\n\n# ---------------------------------------------------------------------------\n# Backward-compatible convenience wrapper\n# ---------------------------------------------------------------------------\n\n\ndef merge_adapter_weights(\n    base_model: torch.nn.Module, adapter_weights: dict[str, torch.Tensor], config: dict\n) -> None:\n    \"\"\"Merge LoRA adapter weights into a base model's state dict in-place.\n\n    Backward-compatible wrapper around :func:`plan_merge_ops` and\n    :func:`apply_merge_op`.\n\n    Handles:\n    - Standard (non-expert) linear layers\n    - Separate per-expert weights (Qwen3 MoE, DeepSeek, Kimi)\n    - Fused expert weights with interleaved layout (GPT-OSS)\n    - Fused expert weights with concatenated layout (Qwen3.5, Qwen3-VL)\n    - Vision model name prefix remapping\n    - GPT-OSS attention name remapping\n\n    Args:\n        base_model: The HuggingFace model to merge into.\n        adapter_weights: Dict of LoRA weight tensors from the adapter.\n        config: Adapter config dict with ``lora_alpha`` and ``r`` keys.\n\n    Raises:\n        KeyError: If required config keys are missing or adapter weight\n            names don't map to any model weight.\n        ValueError: If tensor shapes are incompatible.\n    \"\"\"\n    model_state_dict = base_model.state_dict()\n    model_state_keys = set(model_state_dict.keys())\n\n    # Build a config dict for detect_merge_profile. 
Prefer the model's HF\n    # config (which has the real architectures list) over fragile class name\n    # string matching.\n    config_obj = getattr(base_model, \"config\", None)\n    if config_obj is not None and hasattr(config_obj, \"to_dict\"):\n        model_config = config_obj.to_dict()\n    else:\n        # Fallback for non-HF models (e.g. test mocks)\n        is_gpt_oss = \"GptOss\" in str(type(base_model))\n        model_config = {\"architectures\": [\"GptOssForCausalLM\"] if is_gpt_oss else []}\n    profile = detect_merge_profile(model_config, model_state_keys)\n\n    ops = plan_merge_ops(adapter_weights, config, model_state_keys, profile)\n    for op_list in ops.values():\n        for op in op_list:\n            apply_merge_op(model_state_dict, op)\n"
  },
  {
    "path": "tinker_cookbook/weights/_publish.py",
    "content": "\"\"\"Publish model weights to HuggingFace Hub.\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\n\nfrom huggingface_hub import HfApi\n\n\ndef publish_to_hf_hub(\n    *,\n    model_path: str,\n    repo_id: str,\n    private: bool = True,\n    token: str | None = None,\n) -> str:\n    \"\"\"Push a model or adapter directory to HuggingFace Hub.\n\n    Works with outputs from :func:`build_hf_model`, :func:`build_lora_adapter`,\n    or any HuggingFace-compatible model directory.\n\n    Args:\n        model_path: Local path to the model or adapter directory to upload.\n        repo_id: HuggingFace Hub repository ID (e.g. ``\"user/my-model\"``).\n        private: Whether the repository should be private. Defaults to\n            ``True`` for safety.\n        token: HuggingFace API token. If ``None`` (default), uses the\n            ``HF_TOKEN`` environment variable or cached login from\n            ``hf auth login``.\n\n    Returns:\n        URL of the published repository.\n    \"\"\"\n    path = Path(model_path)\n    if not path.is_dir():\n        raise FileNotFoundError(f\"model_path does not exist or is not a directory: {model_path}\")\n\n    api = HfApi(token=token)\n    api.create_repo(repo_id=repo_id, repo_type=\"model\", private=private, exist_ok=True)\n    api.upload_folder(folder_path=str(path), repo_id=repo_id, repo_type=\"model\")\n\n    return f\"https://huggingface.co/{repo_id}\"\n"
  },
  {
    "path": "tinker_cookbook/weights/artifacts_test.py",
    "content": "\"\"\"Unit tests for model artifact utilities.\n\nUses temporary directories and synthetic safetensors files — no network or\nGPU required.\n\"\"\"\n\nimport json\nfrom pathlib import Path\n\nimport pytest\nimport torch\nfrom safetensors.torch import save_file\n\nfrom tinker_cookbook.weights._artifacts import (\n    ShardWriter,\n    copy_model_code_files,\n    get_model_state_keys,\n    get_model_state_shapes,\n    get_shard_files,\n    load_adapter_weights,\n)\n\n# ---------------------------------------------------------------------------\n# get_model_state_keys\n# ---------------------------------------------------------------------------\n\n\nclass TestGetModelStateKeys:\n    def test_reads_keys_from_single_shard(self, tmp_path: Path):\n        tensors = {\"layer.0.weight\": torch.zeros(4, 4), \"layer.0.bias\": torch.zeros(4)}\n        save_file(tensors, str(tmp_path / \"model.safetensors\"))\n\n        keys = get_model_state_keys(tmp_path)\n        assert keys == {\"layer.0.weight\", \"layer.0.bias\"}\n\n    def test_reads_keys_from_multiple_shards(self, tmp_path: Path):\n        save_file({\"a.weight\": torch.zeros(2)}, str(tmp_path / \"model-00001-of-00002.safetensors\"))\n        save_file({\"b.weight\": torch.zeros(3)}, str(tmp_path / \"model-00002-of-00002.safetensors\"))\n\n        keys = get_model_state_keys(tmp_path)\n        assert keys == {\"a.weight\", \"b.weight\"}\n\n    def test_raises_if_no_safetensors(self, tmp_path: Path):\n        with pytest.raises(FileNotFoundError, match=r\"No \\.safetensors files\"):\n            get_model_state_keys(tmp_path)\n\n\n# ---------------------------------------------------------------------------\n# get_model_state_shapes\n# ---------------------------------------------------------------------------\n\n\nclass TestGetModelStateShapes:\n    def test_reads_shapes(self, tmp_path: Path):\n        tensors = {\"weight\": torch.zeros(8, 4), \"bias\": torch.zeros(8)}\n        save_file(tensors, str(tmp_path / \"model.safetensors\"))\n\n        shapes = get_model_state_shapes(tmp_path)\n        assert shapes == {\"weight\": (8, 4), \"bias\": (8,)}\n\n    def test_reads_shapes_across_shards(self, tmp_path: Path):\n        save_file({\"a\": torch.zeros(2, 3)}, str(tmp_path / \"shard-1.safetensors\"))\n        save_file({\"b\": torch.zeros(4)}, str(tmp_path / \"shard-2.safetensors\"))\n\n        shapes = get_model_state_shapes(tmp_path)\n        assert shapes == {\"a\": (2, 3), \"b\": (4,)}\n\n    def test_helpful_error_for_bin_files(self, tmp_path: Path):\n        (tmp_path / \"pytorch_model.bin\").write_bytes(b\"fake\")\n        with pytest.raises(FileNotFoundError, match=r\"\\.bin file.*merge_strategy='full'\"):\n            get_model_state_shapes(tmp_path)\n\n    def test_helpful_error_for_empty_dir(self, tmp_path: Path):\n        with pytest.raises(FileNotFoundError, match=r\"merge_strategy='full'\"):\n            get_model_state_shapes(tmp_path)\n\n\n# ---------------------------------------------------------------------------\n# get_shard_files\n# ---------------------------------------------------------------------------\n\n\nclass TestGetShardFiles:\n    def test_reads_from_index_json(self, tmp_path: Path):\n        index = {\n            \"weight_map\": {\n                \"a.weight\": \"model-00001-of-00002.safetensors\",\n                \"b.weight\": \"model-00002-of-00002.safetensors\",\n                \"c.weight\": \"model-00001-of-00002.safetensors\",\n            }\n        }\n        (tmp_path / 
\"model.safetensors.index.json\").write_text(json.dumps(index))\n\n        files = get_shard_files(tmp_path)\n        assert files == [\"model-00001-of-00002.safetensors\", \"model-00002-of-00002.safetensors\"]\n\n    def test_falls_back_to_glob(self, tmp_path: Path):\n        save_file({\"a\": torch.zeros(1)}, str(tmp_path / \"model.safetensors\"))\n\n        files = get_shard_files(tmp_path)\n        assert files == [\"model.safetensors\"]\n\n    def test_raises_if_no_files(self, tmp_path: Path):\n        with pytest.raises(FileNotFoundError):\n            get_shard_files(tmp_path)\n\n\n# ---------------------------------------------------------------------------\n# ShardWriter\n# ---------------------------------------------------------------------------\n\n\nclass TestShardWriter:\n    def test_single_shard_named_without_index(self, tmp_path: Path):\n        writer = ShardWriter(tmp_path)\n        writer.add_tensor(\"a.weight\", torch.zeros(4))\n        writer.add_tensor(\"b.weight\", torch.ones(4))\n        weight_map = writer.finalize()\n\n        assert (tmp_path / \"model.safetensors\").exists()\n        assert weight_map == {\n            \"a.weight\": \"model.safetensors\",\n            \"b.weight\": \"model.safetensors\",\n        }\n\n    def test_multiple_shards_when_exceeding_max_size(self, tmp_path: Path):\n        # Each float32 tensor of 1024 elements = 4096 bytes\n        writer = ShardWriter(tmp_path, max_shard_size=4096)\n        writer.add_tensor(\"a.weight\", torch.zeros(1024))  # 4096 bytes, fits\n        writer.add_tensor(\"b.weight\", torch.zeros(1024))  # triggers flush of a, then b pending\n\n        weight_map = writer.finalize()\n        assert len(set(weight_map.values())) == 2\n        assert \"model-00001-of-00002.safetensors\" in weight_map.values()\n        assert \"model-00002-of-00002.safetensors\" in weight_map.values()\n\n    def test_total_size_tracks_bytes(self, tmp_path: Path):\n        writer = ShardWriter(tmp_path)\n        writer.add_tensor(\"x\", torch.zeros(100, dtype=torch.float32))  # 400 bytes\n        assert writer.total_size == 400\n\n    def test_temp_files_cleaned_up(self, tmp_path: Path):\n        writer = ShardWriter(tmp_path)\n        writer.add_tensor(\"a\", torch.zeros(4))\n        writer.finalize()\n\n        # No .tmp files should remain\n        tmp_files = list(tmp_path.glob(\"*.tmp.*\"))\n        assert tmp_files == []\n\n    def test_empty_writer_produces_no_files(self, tmp_path: Path):\n        writer = ShardWriter(tmp_path)\n        weight_map = writer.finalize()\n        assert weight_map == {}\n        assert list(tmp_path.glob(\"*.safetensors\")) == []\n\n\n# ---------------------------------------------------------------------------\n# load_adapter_weights\n# ---------------------------------------------------------------------------\n\n\nclass TestLoadAdapterWeights:\n    def test_loads_weights_and_config(self, tmp_path: Path):\n        weights = {\"lora_A\": torch.ones(2, 4), \"lora_B\": torch.ones(8, 2)}\n        save_file(weights, str(tmp_path / \"adapter_model.safetensors\"))\n        (tmp_path / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": 2}))\n\n        loaded_weights, config = load_adapter_weights(tmp_path)\n        assert set(loaded_weights.keys()) == {\"lora_A\", \"lora_B\"}\n        assert config[\"r\"] == 2\n\n    def test_missing_safetensors_raises(self, tmp_path: Path):\n        (tmp_path / \"adapter_config.json\").write_text(\"{}\")\n        with 
pytest.raises(FileNotFoundError, match=r\"adapter_model\\.safetensors\"):\n            load_adapter_weights(tmp_path)\n\n    def test_missing_config_raises(self, tmp_path: Path):\n        save_file({\"x\": torch.zeros(1)}, str(tmp_path / \"adapter_model.safetensors\"))\n        with pytest.raises(FileNotFoundError, match=r\"adapter_config\\.json\"):\n            load_adapter_weights(tmp_path)\n\n\n# ---------------------------------------------------------------------------\n# copy_model_code_files\n# ---------------------------------------------------------------------------\n\n\nclass TestCopyModelCodeFiles:\n    def test_copies_only_py_files(self, tmp_path: Path):\n        src = tmp_path / \"src\"\n        dst = tmp_path / \"dst\"\n        src.mkdir()\n        dst.mkdir()\n\n        (src / \"modeling_custom.py\").write_text(\"# model code\")\n        (src / \"configuration_custom.py\").write_text(\"# config code\")\n        # Non-py files should NOT be copied\n        (src / \"config.json\").write_text('{\"model_type\": \"test\"}')\n        (src / \"tokenizer.model\").write_text(\"tokenizer data\")\n        save_file({\"x\": torch.zeros(1)}, str(src / \"model.safetensors\"))\n\n        copy_model_code_files(src, dst)\n\n        assert (dst / \"modeling_custom.py\").exists()\n        assert (dst / \"configuration_custom.py\").exists()\n        assert not (dst / \"config.json\").exists()\n        assert not (dst / \"tokenizer.model\").exists()\n        assert not (dst / \"model.safetensors\").exists()\n\n    def test_does_not_overwrite_existing(self, tmp_path: Path):\n        src = tmp_path / \"src\"\n        dst = tmp_path / \"dst\"\n        src.mkdir()\n        dst.mkdir()\n\n        (src / \"modeling.py\").write_text(\"source\")\n        (dst / \"modeling.py\").write_text(\"existing\")\n\n        copy_model_code_files(src, dst)\n\n        assert (dst / \"modeling.py\").read_text() == \"existing\"\n"
  },
  {
    "path": "tinker_cookbook/weights/download_test.py",
    "content": "\"\"\"Tests for the download function.\"\"\"\n\nimport tarfile\nimport tempfile\nfrom pathlib import Path\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom tinker_cookbook.exceptions import WeightsDownloadError\nfrom tinker_cookbook.weights._download import _safe_extract_tar, download\n\n\nclass TestSafeExtractTar:\n    \"\"\"Security validation for tar extraction.\"\"\"\n\n    def test_rejects_symlinks(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            archive_path = root / \"bad.tar\"\n            extract_dir = root / \"extract\"\n            extract_dir.mkdir()\n\n            target = root / \"target.txt\"\n            target.write_text(\"target\")\n            link = root / \"link\"\n            link.symlink_to(target)\n\n            with tarfile.open(archive_path, \"w\") as tar:\n                tar.add(link, arcname=\"link\")\n\n            with pytest.raises(WeightsDownloadError, match=\"symlink\"):\n                _safe_extract_tar(archive_path, extract_dir)\n\n    def test_rejects_path_traversal(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            archive_path = root / \"bad.tar\"\n            extract_dir = root / \"extract\"\n            extract_dir.mkdir()\n\n            normal_file = root / \"normal.txt\"\n            normal_file.write_text(\"content\")\n\n            with tarfile.open(archive_path, \"w\") as tar:\n                tar.add(normal_file, arcname=\"../../../etc/passwd\")\n\n            with pytest.raises(WeightsDownloadError, match=\"path traversal\"):\n                _safe_extract_tar(archive_path, extract_dir)\n\n    def test_extracts_safe_archive(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            archive_path = root / \"good.tar\"\n            extract_dir = root / \"extract\"\n            extract_dir.mkdir()\n\n            content_file = root / \"data.txt\"\n            content_file.write_text(\"hello\")\n\n            with tarfile.open(archive_path, \"w\") as tar:\n                tar.add(content_file, arcname=\"data.txt\")\n\n            _safe_extract_tar(archive_path, extract_dir)\n            assert (extract_dir / \"data.txt\").exists()\n\n\nclass TestDownload:\n    \"\"\"Tests for the download function with mocked Tinker SDK.\"\"\"\n\n    def test_downloads_and_extracts(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            root = Path(tmpdir)\n            archive_path = root / \"archive.tar\"\n            content_dir = root / \"content\"\n            content_dir.mkdir()\n            (content_dir / \"adapter_model.safetensors\").write_text(\"fake\")\n            (content_dir / \"adapter_config.json\").write_text(\"{}\")\n\n            with tarfile.open(archive_path, \"w\") as tar:\n                tar.add(\n                    content_dir / \"adapter_model.safetensors\", arcname=\"adapter_model.safetensors\"\n                )\n                tar.add(content_dir / \"adapter_config.json\", arcname=\"adapter_config.json\")\n\n            output_dir = root / \"output\"\n\n            mock_response = MagicMock()\n            mock_response.url = f\"file://{archive_path}\"\n\n            mock_future = MagicMock()\n            mock_future.result.return_value = mock_response\n\n            mock_rest_client = MagicMock()\n            mock_rest_client.get_checkpoint_archive_url_from_tinker_path.return_value = mock_future\n\n            
mock_service_client = MagicMock()\n            mock_service_client.create_rest_client.return_value = mock_rest_client\n\n            def fake_urlretrieve(url: str, dest: str) -> None:\n                import shutil\n\n                shutil.copy2(str(archive_path), dest)\n\n            with (\n                patch(\n                    \"tinker_cookbook.weights._download.tinker.ServiceClient\",\n                    return_value=mock_service_client,\n                ),\n                patch(\n                    \"tinker_cookbook.weights._download.urllib.request.urlretrieve\",\n                    fake_urlretrieve,\n                ),\n            ):\n                result = download(\n                    tinker_path=\"tinker://fake-run/sampler_weights/final\",\n                    output_dir=str(output_dir),\n                )\n\n            assert result == str(output_dir)\n            assert (output_dir / \"adapter_model.safetensors\").exists()\n            assert (output_dir / \"adapter_config.json\").exists()\n"
  },
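  {
    "path": "tinker_cookbook/weights/examples/download_sketch.py",
    "content": "\"\"\"Illustrative usage sketch for the download function (not part of the original test suite).\n\nAdded for documentation purposes: the import path and call signature mirror\nwhat download_test.py exercises. The tinker path, output directory, and this\nfile's location are placeholders, not real project artifacts.\n\"\"\"\n\nfrom tinker_cookbook.weights._download import download\n\nif __name__ == \"__main__\":\n    # download() resolves the archive URL for a tinker:// checkpoint path,\n    # fetches and extracts it, and returns the output directory, which then\n    # contains adapter_model.safetensors and adapter_config.json\n    # (per the assertions in download_test.py).\n    local_dir = download(\n        tinker_path=\"tinker://<run-id>/sampler_weights/final\",  # placeholder\n        output_dir=\"/tmp/my_adapter\",  # placeholder\n    )\n    print(f\"Adapter extracted to {local_dir}\")\n"
  },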
  {
    "path": "tinker_cookbook/weights/export_test.py",
    "content": "\"\"\"Unit tests for build_hf_model strategy dispatch and the sharded export path.\n\nUses synthetic safetensors files and adapter weights to test the shard-by-shard\npipeline end-to-end without requiring real HF models or network access.\n\"\"\"\n\nimport json\nfrom pathlib import Path\n\nimport pytest\nimport torch\nfrom safetensors.torch import load_file, save_file\n\nfrom tinker_cookbook.weights._export import build_hf_model\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _create_synthetic_model(model_dir: Path, config_dict: dict, state_dict: dict) -> None:\n    \"\"\"Create a minimal synthetic HF model directory.\n\n    Writes config.json, a single safetensors shard, and a minimal tokenizer.\n    \"\"\"\n    model_dir.mkdir(parents=True, exist_ok=True)\n    (model_dir / \"config.json\").write_text(json.dumps(config_dict))\n    save_file(state_dict, str(model_dir / \"model.safetensors\"))\n    # Minimal tokenizer files so AutoTokenizer doesn't fail\n    (model_dir / \"tokenizer_config.json\").write_text(\n        json.dumps({\"tokenizer_class\": \"PreTrainedTokenizerFast\"})\n    )\n    (model_dir / \"tokenizer.json\").write_text(\n        json.dumps(\n            {\n                \"version\": \"1.0\",\n                \"model\": {\"type\": \"BPE\", \"vocab\": {\"a\": 0, \"b\": 1}, \"merges\": []},\n                \"added_tokens\": [],\n            }\n        )\n    )\n\n\ndef _create_sharded_model(\n    model_dir: Path, config_dict: dict, shards: dict[str, dict[str, torch.Tensor]]\n) -> None:\n    \"\"\"Create a synthetic model with multiple safetensors shards.\"\"\"\n    model_dir.mkdir(parents=True, exist_ok=True)\n    (model_dir / \"config.json\").write_text(json.dumps(config_dict))\n\n    weight_map: dict[str, str] = {}\n    for shard_name, tensors in shards.items():\n        save_file(tensors, str(model_dir / shard_name))\n        for key in tensors:\n            weight_map[key] = shard_name\n\n    index = {\"metadata\": {\"total_size\": 0}, \"weight_map\": weight_map}\n    (model_dir / \"model.safetensors.index.json\").write_text(json.dumps(index))\n\n    (model_dir / \"tokenizer_config.json\").write_text(\n        json.dumps({\"tokenizer_class\": \"PreTrainedTokenizerFast\"})\n    )\n    (model_dir / \"tokenizer.json\").write_text(\n        json.dumps(\n            {\n                \"version\": \"1.0\",\n                \"model\": {\"type\": \"BPE\", \"vocab\": {\"a\": 0, \"b\": 1}, \"merges\": []},\n                \"added_tokens\": [],\n            }\n        )\n    )\n\n\ndef _create_adapter(adapter_dir: Path, weights: dict[str, torch.Tensor], config: dict) -> None:\n    \"\"\"Create a synthetic adapter directory.\"\"\"\n    adapter_dir.mkdir(parents=True, exist_ok=True)\n    save_file(weights, str(adapter_dir / \"adapter_model.safetensors\"))\n    (adapter_dir / \"adapter_config.json\").write_text(json.dumps(config))\n\n\n# ---------------------------------------------------------------------------\n# Strategy dispatch\n# ---------------------------------------------------------------------------\n\n\nclass TestBuildHfModelDispatch:\n    def test_invalid_strategy_raises(self, tmp_path: Path):\n        with pytest.raises(ValueError, match=\"merge_strategy\"):\n            build_hf_model(\n                base_model=str(tmp_path),\n                adapter_path=str(tmp_path),\n                output_path=str(tmp_path / 
\"out\"),\n                merge_strategy=\"invalid\",\n            )\n\n    def test_dequantize_raises_not_implemented(self, tmp_path: Path):\n        with pytest.raises(NotImplementedError, match=\"dequantize\"):\n            build_hf_model(\n                base_model=str(tmp_path),\n                adapter_path=str(tmp_path),\n                output_path=str(tmp_path / \"out\"),\n                dequantize=True,\n            )\n\n    def test_invalid_dtype_raises(self, tmp_path: Path):\n        with pytest.raises(ValueError, match=\"dtype\"):\n            build_hf_model(\n                base_model=str(tmp_path),\n                adapter_path=str(tmp_path),\n                output_path=str(tmp_path / \"out\"),\n                dtype=\"float8\",\n            )\n\n\n# ---------------------------------------------------------------------------\n# Shard-by-shard end-to-end (single shard)\n# ---------------------------------------------------------------------------\n\n\nclass TestBuildShardedSingleShard:\n    \"\"\"End-to-end test of shard-by-shard merge with a single-shard synthetic model.\"\"\"\n\n    def test_merges_adapter_into_single_shard(self, tmp_path: Path):\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n\n        # Create a synthetic model with one linear layer\n        config = {\"architectures\": [\"TestModel\"], \"model_type\": \"test\"}\n        state_dict = {\n            \"model.layers.0.self_attn.q_proj.weight\": torch.zeros(8, 4, dtype=torch.float32),\n            \"model.layers.0.mlp.gate_proj.weight\": torch.zeros(8, 4, dtype=torch.float32),\n        }\n        _create_synthetic_model(model_dir, config, state_dict)\n\n        # Create adapter targeting q_proj\n        adapter_weights = {\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n        _create_adapter(adapter_dir, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n        # Run sharded merge\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            merge_strategy=\"shard\",\n        )\n\n        # Verify output structure\n        assert (output_dir / \"config.json\").exists()\n        assert (output_dir / \"model.safetensors\").exists() or (\n            output_dir / \"model.safetensors.index.json\"\n        ).exists()\n\n        # Load and verify merged weights\n        out_tensors = _load_output_tensors(output_dir)\n        q_proj = out_tensors[\"model.layers.0.self_attn.q_proj.weight\"]\n        gate_proj = out_tensors[\"model.layers.0.mlp.gate_proj.weight\"]\n\n        # q_proj should have LoRA delta applied (all ones from B @ A)\n        assert q_proj.abs().sum() > 0\n        assert torch.allclose(q_proj, torch.ones(8, 4))\n\n        # gate_proj should be unchanged (no adapter targeting it)\n        assert gate_proj.abs().sum() == 0\n\n\n# ---------------------------------------------------------------------------\n# Shard-by-shard end-to-end (multiple shards)\n# ---------------------------------------------------------------------------\n\n\nclass TestBuildShardedMultiShard:\n    \"\"\"End-to-end test of shard-by-shard merge with a multi-shard synthetic model.\"\"\"\n\n    def test_merges_across_shards(self, tmp_path: Path):\n        model_dir = tmp_path / 
\"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n\n        # Create a model with weights split across two shards\n        config = {\"architectures\": [\"TestModel\"], \"model_type\": \"test\"}\n        shards = {\n            \"model-00001-of-00002.safetensors\": {\n                \"model.layers.0.self_attn.q_proj.weight\": torch.zeros(8, 4, dtype=torch.float32),\n            },\n            \"model-00002-of-00002.safetensors\": {\n                \"model.layers.0.mlp.gate_proj.weight\": torch.zeros(8, 4, dtype=torch.float32),\n            },\n        }\n        _create_sharded_model(model_dir, config, shards)\n\n        # Adapter targets weights in both shards\n        adapter_weights = {\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": torch.ones(1, 4)\n            * 0.5,\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n            \"base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight\": torch.ones(1, 4) * 0.3,\n            \"base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n        _create_adapter(adapter_dir, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            merge_strategy=\"shard\",\n        )\n\n        out_tensors = _load_output_tensors(output_dir)\n\n        # Both should have their respective deltas\n        q_proj = out_tensors[\"model.layers.0.self_attn.q_proj.weight\"]\n        gate_proj = out_tensors[\"model.layers.0.mlp.gate_proj.weight\"]\n\n        assert torch.allclose(q_proj, torch.full((8, 4), 0.5), atol=1e-6)\n        assert torch.allclose(gate_proj, torch.full((8, 4), 0.3), atol=1e-6)\n\n\n# ---------------------------------------------------------------------------\n# Shard-by-shard with separate experts\n# ---------------------------------------------------------------------------\n\n\nclass TestBuildShardedSeparateExperts:\n    \"\"\"Shard-by-shard merge with per-expert weights (no fused gate_up_proj).\"\"\"\n\n    def test_merges_per_expert_weights(self, tmp_path: Path):\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n\n        num_experts = 2\n        config = {\"architectures\": [\"TestMoEModel\"], \"model_type\": \"test\"}\n        state_dict = {\n            f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\": torch.zeros(\n                8, 4, dtype=torch.float32\n            )\n            for i in range(num_experts)\n        }\n        _create_synthetic_model(model_dir, config, state_dict)\n\n        adapter_weights = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": torch.ones(\n                num_experts, 1, 4\n            )\n            * 0.1,\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": torch.ones(\n                num_experts, 8, 1\n            ),\n        }\n        _create_adapter(adapter_dir, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            merge_strategy=\"shard\",\n        )\n\n        out_tensors = _load_output_tensors(output_dir)\n        for i in range(num_experts):\n            w = 
out_tensors[f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\"]\n            assert torch.allclose(w, torch.full((8, 4), 0.1), atol=1e-6), f\"Expert {i} incorrect\"\n\n\n# ---------------------------------------------------------------------------\n# Shard-by-shard with fused experts (concatenated layout)\n# ---------------------------------------------------------------------------\n\n\nclass TestBuildShardedFusedExperts:\n    \"\"\"Shard-by-shard merge with fused gate_up_proj (concatenated layout).\"\"\"\n\n    NUM_EXPERTS = 2\n    IN_DIM = 4\n    OUT_DIM = 4\n    FUSED_DIM = OUT_DIM * 2\n\n    def test_merges_gate_and_up_into_correct_halves(self, tmp_path: Path):\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n\n        config = {\"architectures\": [\"TestMoEModel\"], \"model_type\": \"test\"}\n        state_dict = {\n            \"model.layers.0.mlp.experts.gate_up_proj\": torch.zeros(\n                self.NUM_EXPERTS, self.IN_DIM, self.FUSED_DIM, dtype=torch.float32\n            ),\n        }\n        _create_synthetic_model(model_dir, config, state_dict)\n\n        # Adapter for gate (w1) and up (w3) projections\n        prefix = \"base_model.model.model.layers.0.mlp.experts\"\n        gate_fill, up_fill = 0.02, 0.07\n        adapter_weights = {\n            f\"{prefix}.w1.lora_A.weight\": torch.ones(self.NUM_EXPERTS, 1, self.IN_DIM) * gate_fill,\n            f\"{prefix}.w1.lora_B.weight\": torch.ones(self.NUM_EXPERTS, self.OUT_DIM, 1),\n            f\"{prefix}.w3.lora_A.weight\": torch.ones(self.NUM_EXPERTS, 1, self.IN_DIM) * up_fill,\n            f\"{prefix}.w3.lora_B.weight\": torch.ones(self.NUM_EXPERTS, self.OUT_DIM, 1),\n        }\n        _create_adapter(adapter_dir, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            merge_strategy=\"shard\",\n        )\n\n        out_tensors = _load_output_tensors(output_dir)\n        fused = out_tensors[\"model.layers.0.mlp.experts.gate_up_proj\"]\n        sz = self.FUSED_DIM // 2\n        gate_half = fused[:, :, :sz]\n        up_half = fused[:, :, sz:]\n\n        assert torch.allclose(gate_half, torch.full_like(gate_half, gate_fill), atol=1e-6)\n        assert torch.allclose(up_half, torch.full_like(up_half, up_fill), atol=1e-6)\n\n    def test_single_shard_output_has_no_index(self, tmp_path: Path):\n        \"\"\"Single-shard models should produce model.safetensors without an index.\"\"\"\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n\n        config = {\"architectures\": [\"TestModel\"], \"model_type\": \"test\"}\n        state_dict = {\"model.layers.0.proj.weight\": torch.zeros(4, 4, dtype=torch.float32)}\n        _create_synthetic_model(model_dir, config, state_dict)\n\n        adapter_weights = {\n            \"base_model.model.model.layers.0.proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.proj.lora_B.weight\": torch.ones(4, 1),\n        }\n        _create_adapter(adapter_dir, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            merge_strategy=\"shard\",\n        )\n\n        assert (output_dir / 
\"model.safetensors\").exists()\n        assert not (output_dir / \"model.safetensors.index.json\").exists()\n\n\n# ---------------------------------------------------------------------------\n# Cleanup on failure\n# ---------------------------------------------------------------------------\n\n\nclass TestBuildShardedCleanup:\n    def test_cleans_up_on_failure(self, tmp_path: Path):\n        output_dir = tmp_path / \"output\"\n        adapter_dir = tmp_path / \"adapter\"\n\n        # Create adapter but no model — will fail when trying to resolve model dir\n        _create_adapter(adapter_dir, {\"x.lora_A.weight\": torch.zeros(1)}, {\"lora_alpha\": 1, \"r\": 1})\n\n        with pytest.raises(Exception):  # noqa: B017\n            build_hf_model(\n                base_model=str(tmp_path / \"nonexistent_model\"),\n                adapter_path=str(adapter_dir),\n                output_path=str(output_dir),\n                merge_strategy=\"shard\",\n            )\n\n        # Output dir should not exist after cleanup\n        assert not output_dir.exists()\n\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _load_output_tensors(output_dir: Path) -> dict[str, torch.Tensor]:\n    \"\"\"Load all tensors from an output directory (single or sharded).\"\"\"\n    single = output_dir / \"model.safetensors\"\n    if single.exists():\n        return load_file(str(single))\n\n    index_path = output_dir / \"model.safetensors.index.json\"\n    assert index_path.exists(), f\"No model.safetensors or index.json in {output_dir}\"\n    with open(index_path) as f:\n        weight_map = json.load(f)[\"weight_map\"]\n\n    tensors: dict[str, torch.Tensor] = {}\n    for shard_name in sorted(set(weight_map.values())):\n        tensors.update(load_file(str(output_dir / shard_name)))\n    return tensors\n"
  },
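  {
    "path": "tinker_cookbook/weights/examples/export_sketch.py",
    "content": "\"\"\"Illustrative usage sketch for build_hf_model (not part of the original test suite).\n\nAdded for documentation purposes: the call mirrors the signature exercised in\nexport_test.py. The base model, adapter, and output paths, and this file's\nlocation, are placeholders.\n\"\"\"\n\nfrom tinker_cookbook.weights._export import build_hf_model\n\nif __name__ == \"__main__\":\n    # merge_strategy=\"shard\" merges the LoRA adapter into the base model one\n    # safetensors shard at a time; export_test.py checks that the output\n    # directory contains config.json plus either model.safetensors or\n    # model.safetensors.index.json with its shards.\n    build_hf_model(\n        base_model=\"/path/to/base-model\",  # placeholder: local HF model directory\n        adapter_path=\"/path/to/adapter\",  # placeholder: contains adapter_model.safetensors\n        output_path=\"/path/to/merged-model\",  # placeholder\n        merge_strategy=\"shard\",\n    )\n"
  },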
  {
    "path": "tinker_cookbook/weights/merge_test.py",
    "content": "\"\"\"Unit tests for LoRA merge logic.\n\nUses synthetic tensors to cover all code paths without needing real models\nor network access.\n\"\"\"\n\nfrom typing import Any\n\nimport pytest\nimport torch\n\nfrom tinker_cookbook.exceptions import WeightsMergeError\nfrom tinker_cookbook.weights._merge import (\n    MergeOp,\n    MergeProfile,\n    apply_merge_op,\n    apply_merged_weight,\n    detect_merge_profile,\n    expand_expert_lora_tensors,\n    merge_adapter_weights,\n    merge_lora_matrices,\n    plan_merge_ops,\n    validate_merge_op_shapes,\n)\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _make_base_model(state_dict: dict[str, torch.Tensor], class_name: str = \"SomeModel\") -> Any:\n    \"\"\"Create a minimal mock model with a real state_dict and controllable class name.\n\n    Uses a dynamically-created class so ``str(type(model))`` contains the\n    desired class name (important for GPT-OSS detection).\n    \"\"\"\n    cls = type(class_name, (), {\"state_dict\": lambda self: state_dict})\n    return cls()\n\n\ndef _make_expert_lora_pair(\n    num_experts: int, out_dim: int, in_dim: int, rank: int = 1, fill: float = 1.0\n) -> tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"Create LoRA A/B pair for experts with predictable merged output.\n\n    lora_A = fill * ones, lora_B = ones → merged = fill * ones(in_dim, out_dim) * rank.\n    \"\"\"\n    lora_A = torch.ones(num_experts, rank, in_dim) * fill\n    lora_B = torch.ones(num_experts, out_dim, rank)\n    return lora_A, lora_B\n\n\n# ---------------------------------------------------------------------------\n# apply_merged_weight\n# ---------------------------------------------------------------------------\n\n\nclass TestApplyMergedWeight:\n    def test_adds_delta_in_place(self):\n        target = torch.zeros(3, 4)\n        delta = torch.ones(3, 4) * 0.5\n        apply_merged_weight(target, delta)\n        assert torch.allclose(target, torch.full((3, 4), 0.5))\n\n    def test_raises_on_shape_mismatch(self):\n        with pytest.raises(WeightsMergeError, match=\"Shape mismatch\"):\n            apply_merged_weight(torch.zeros(3, 4), torch.zeros(3, 5))\n\n\n# ---------------------------------------------------------------------------\n# Config validation\n# ---------------------------------------------------------------------------\n\n\nclass TestConfigValidation:\n    def test_missing_lora_alpha(self):\n        model = _make_base_model({})\n        with pytest.raises(WeightsMergeError, match=\"lora_alpha\"):\n            merge_adapter_weights(model, {}, {\"r\": 1})\n\n    def test_missing_r(self):\n        model = _make_base_model({})\n        with pytest.raises(WeightsMergeError, match=\"'r'\"):\n            merge_adapter_weights(model, {}, {\"lora_alpha\": 1})\n\n\n# ---------------------------------------------------------------------------\n# Non-expert linear layers\n# ---------------------------------------------------------------------------\n\n\nclass TestNonExpertMerge:\n    def test_standard_linear_merge(self):\n        state_dict = {\"model.layers.0.self_attn.q_proj.weight\": torch.zeros(8, 4)}\n        model = _make_base_model(state_dict)\n\n        # rank=1 LoRA: A=(1,4), B=(8,1) → merged=(8,4) all equal to fill*scaling\n        adapter_weights = {\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": torch.ones(1, 4),\n            
\"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n\n        merge_adapter_weights(model, adapter_weights, {\"lora_alpha\": 2, \"r\": 1})\n\n        result = state_dict[\"model.layers.0.self_attn.q_proj.weight\"]\n        # merged = linear(A.T, B*scaling).T = linear((4,1), (8,1)*2).T\n        # = (4,8)*2 transposed... actually let's just check it's nonzero and uniform\n        assert result.abs().sum() > 0\n        assert torch.allclose(result, result[0, 0].expand_as(result))\n\n    def test_missing_target_key_raises(self):\n        model = _make_base_model({\"some.other.weight\": torch.zeros(4, 4)})\n        adapter_weights = {\n            \"base_model.model.model.layers.0.q_proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.q_proj.lora_B.weight\": torch.ones(4, 1),\n        }\n        with pytest.raises(WeightsMergeError, match=\"does not exist in the model state dict\"):\n            merge_adapter_weights(model, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n    def test_gpt_oss_attn_remapping(self):\n        state_dict = {\"model.layers.0.self_attn.q_proj.weight\": torch.zeros(8, 4)}\n        model = _make_base_model(state_dict, class_name=\"GptOssForCausalLM\")\n\n        # Tinker adapter uses .attn instead of .self_attn for GPT-OSS\n        adapter_weights = {\n            \"base_model.model.model.layers.0.attn.q_proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n\n        merge_adapter_weights(model, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n        assert state_dict[\"model.layers.0.self_attn.q_proj.weight\"].abs().sum() > 0\n\n    def test_vision_model_prefix_remapping(self):\n        state_dict = {\"model.language_model.layers.0.self_attn.q_proj.weight\": torch.zeros(8, 4)}\n        model = _make_base_model(state_dict)\n\n        adapter_weights = {\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n\n        merge_adapter_weights(model, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n        assert state_dict[\"model.language_model.layers.0.self_attn.q_proj.weight\"].abs().sum() > 0\n\n\n# ---------------------------------------------------------------------------\n# Separate per-expert weights (Qwen3 MoE, DeepSeek, Kimi)\n# ---------------------------------------------------------------------------\n\n\nclass TestSeparateExpertMerge:\n    def test_per_expert_merge(self):\n        num_experts = 2\n        state_dict = {\n            f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\": torch.zeros(8, 4)\n            for i in range(num_experts)\n        }\n        model = _make_base_model(state_dict)\n\n        gate_A, gate_B = _make_expert_lora_pair(num_experts, 8, 4, fill=0.1)\n        adapter_weights = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": gate_A,\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": gate_B,\n        }\n\n        merge_adapter_weights(model, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n        for i in range(num_experts):\n            w = state_dict[f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\"]\n            assert w.abs().sum() > 0, f\"Expert {i} was not updated\"\n\n    def test_shared_lora_a_broadcast(self):\n        \"\"\"lora_A has 
1 expert, lora_B has N — A should be broadcast.\"\"\"\n        num_experts = 3\n        state_dict = {\n            f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\": torch.zeros(8, 4)\n            for i in range(num_experts)\n        }\n        model = _make_base_model(state_dict)\n\n        lora_A = torch.ones(1, 1, 4) * 0.5  # shared across experts\n        lora_B = torch.ones(num_experts, 8, 1)\n        adapter_weights = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": lora_A,\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": lora_B,\n        }\n\n        merge_adapter_weights(model, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n        for i in range(num_experts):\n            assert state_dict[f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\"].abs().sum() > 0\n\n\n# ---------------------------------------------------------------------------\n# Fused expert weights — interleaved (GPT-OSS)\n# ---------------------------------------------------------------------------\n\n\nclass TestFusedInterleavedMerge:\n    \"\"\"GPT-OSS: gate_up_proj uses [g0, u0, g1, u1, ...] layout.\"\"\"\n\n    NUM_EXPERTS = 2\n    IN_DIM = 4\n    OUT_DIM = 4\n    FUSED_DIM = OUT_DIM * 2\n\n    def _make_state_dict(self) -> dict[str, torch.Tensor]:\n        return {\n            \"model.layers.0.mlp.experts.gate_up_proj\": torch.zeros(\n                self.NUM_EXPERTS, self.IN_DIM, self.FUSED_DIM\n            ),\n            \"model.layers.0.mlp.experts.down_proj\": torch.zeros(\n                self.NUM_EXPERTS, self.OUT_DIM, self.IN_DIM\n            ),\n        }\n\n    def _make_adapter(self, gate_fill: float, up_fill: float) -> dict[str, torch.Tensor]:\n        prefix = \"base_model.model.model.layers.0.mlp.experts\"\n        gate_A, gate_B = _make_expert_lora_pair(\n            self.NUM_EXPERTS, self.OUT_DIM, self.IN_DIM, fill=gate_fill\n        )\n        up_A, up_B = _make_expert_lora_pair(\n            self.NUM_EXPERTS, self.OUT_DIM, self.IN_DIM, fill=up_fill\n        )\n        return {\n            f\"{prefix}.w1.lora_A.weight\": gate_A,\n            f\"{prefix}.w1.lora_B.weight\": gate_B,\n            f\"{prefix}.w3.lora_A.weight\": up_A,\n            f\"{prefix}.w3.lora_B.weight\": up_B,\n        }\n\n    def test_gate_and_up_in_correct_slots(self):\n        state_dict = self._make_state_dict()\n        model = _make_base_model(state_dict, class_name=\"GptOssModel\")\n        adapter = self._make_adapter(gate_fill=0.01, up_fill=0.05)\n\n        merge_adapter_weights(model, adapter, {\"lora_alpha\": 1, \"r\": 1})\n\n        fused = state_dict[\"model.layers.0.mlp.experts.gate_up_proj\"]\n        gate_slots = fused[:, :, 0::2]\n        up_slots = fused[:, :, 1::2]\n\n        assert torch.allclose(gate_slots, torch.full_like(gate_slots, 0.01), atol=1e-6)\n        assert torch.allclose(up_slots, torch.full_like(up_slots, 0.05), atol=1e-6)\n\n    def test_up_does_not_leak_into_gate(self):\n        state_dict = self._make_state_dict()\n        model = _make_base_model(state_dict, class_name=\"GptOssModel\")\n\n        prefix = \"base_model.model.model.layers.0.mlp.experts\"\n        up_A, up_B = _make_expert_lora_pair(self.NUM_EXPERTS, self.OUT_DIM, self.IN_DIM, fill=0.1)\n        adapter = {\n            f\"{prefix}.w3.lora_A.weight\": up_A,\n            f\"{prefix}.w3.lora_B.weight\": up_B,\n        }\n\n        merge_adapter_weights(model, adapter, {\"lora_alpha\": 1, \"r\": 1})\n\n        fused = 
state_dict[\"model.layers.0.mlp.experts.gate_up_proj\"]\n        assert fused[:, :, 0::2].abs().max() == 0.0, \"up delta leaked into gate slots\"\n        assert fused[:, :, 1::2].abs().sum() > 0\n\n\n# ---------------------------------------------------------------------------\n# Fused expert weights — concatenated (Qwen3.5, Qwen3-VL)\n# ---------------------------------------------------------------------------\n\n\nclass TestFusedConcatenatedMerge:\n    \"\"\"Non-GPT-OSS fused: gate_up_proj uses [gate | up] layout.\"\"\"\n\n    NUM_EXPERTS = 2\n    IN_DIM = 4\n    OUT_DIM = 4\n    FUSED_DIM = OUT_DIM * 2\n\n    def _make_state_dict(self) -> dict[str, torch.Tensor]:\n        return {\n            \"model.layers.0.mlp.experts.gate_up_proj\": torch.zeros(\n                self.NUM_EXPERTS, self.IN_DIM, self.FUSED_DIM\n            ),\n            \"model.layers.0.mlp.experts.down_proj\": torch.zeros(\n                self.NUM_EXPERTS, self.OUT_DIM, self.IN_DIM\n            ),\n        }\n\n    def _make_adapter(self, gate_fill: float, up_fill: float) -> dict[str, torch.Tensor]:\n        prefix = \"base_model.model.model.layers.0.mlp.experts\"\n        gate_A, gate_B = _make_expert_lora_pair(\n            self.NUM_EXPERTS, self.OUT_DIM, self.IN_DIM, fill=gate_fill\n        )\n        up_A, up_B = _make_expert_lora_pair(\n            self.NUM_EXPERTS, self.OUT_DIM, self.IN_DIM, fill=up_fill\n        )\n        return {\n            f\"{prefix}.w1.lora_A.weight\": gate_A,\n            f\"{prefix}.w1.lora_B.weight\": gate_B,\n            f\"{prefix}.w3.lora_A.weight\": up_A,\n            f\"{prefix}.w3.lora_B.weight\": up_B,\n        }\n\n    def test_gate_and_up_in_correct_halves(self):\n        state_dict = self._make_state_dict()\n        model = _make_base_model(state_dict, class_name=\"QwenModel\")\n        adapter = self._make_adapter(gate_fill=0.02, up_fill=0.07)\n\n        merge_adapter_weights(model, adapter, {\"lora_alpha\": 1, \"r\": 1})\n\n        fused = state_dict[\"model.layers.0.mlp.experts.gate_up_proj\"]\n        sz = self.FUSED_DIM // 2\n        gate_half = fused[:, :, :sz]\n        up_half = fused[:, :, sz:]\n\n        assert torch.allclose(gate_half, torch.full_like(gate_half, 0.02), atol=1e-6)\n        assert torch.allclose(up_half, torch.full_like(up_half, 0.07), atol=1e-6)\n\n    def test_up_does_not_leak_into_gate(self):\n        state_dict = self._make_state_dict()\n        model = _make_base_model(state_dict, class_name=\"QwenModel\")\n\n        prefix = \"base_model.model.model.layers.0.mlp.experts\"\n        up_A, up_B = _make_expert_lora_pair(self.NUM_EXPERTS, self.OUT_DIM, self.IN_DIM, fill=0.1)\n        adapter = {\n            f\"{prefix}.w3.lora_A.weight\": up_A,\n            f\"{prefix}.w3.lora_B.weight\": up_B,\n        }\n\n        merge_adapter_weights(model, adapter, {\"lora_alpha\": 1, \"r\": 1})\n\n        fused = state_dict[\"model.layers.0.mlp.experts.gate_up_proj\"]\n        sz = self.FUSED_DIM // 2\n        assert fused[:, :, :sz].abs().max() == 0.0, \"up delta leaked into gate half\"\n        assert fused[:, :, sz:].abs().sum() > 0\n\n\n# ---------------------------------------------------------------------------\n# Error cases for expert LoRA\n# ---------------------------------------------------------------------------\n\n\nclass TestExpertErrorCases:\n    def test_non_3d_expert_lora_raises(self):\n        state_dict = {\"model.layers.0.mlp.experts.0.gate_proj.weight\": torch.zeros(8, 4)}\n        model = _make_base_model(state_dict)\n\n        
adapter_weights = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": torch.ones(1, 4),  # 2D!\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": torch.ones(8, 1),  # 2D!\n        }\n        with pytest.raises(WeightsMergeError, match=\"must be 3D\"):\n            merge_adapter_weights(model, adapter_weights, {\"lora_alpha\": 1, \"r\": 1})\n\n\n# ===========================================================================\n# Tests for new APIs: MergeProfile, detect_merge_profile, plan/apply,\n# merge_lora_matrices, expand_expert_lora_tensors\n# ===========================================================================\n\n# ---------------------------------------------------------------------------\n# merge_lora_matrices\n# ---------------------------------------------------------------------------\n\n\nclass TestMergeLoraMatrices:\n    def test_basic_multiplication(self):\n        lora_A = torch.ones(1, 4)  # (rank=1, in_dim=4)\n        lora_B = torch.ones(8, 1)  # (out_dim=8, rank=1)\n        result = merge_lora_matrices(lora_A, lora_B)\n        assert result.shape == (8, 4)\n        assert torch.allclose(result, torch.ones(8, 4))\n\n    def test_with_scaling(self):\n        lora_A = torch.ones(2, 3) * 0.5\n        lora_B = torch.eye(3, 2)  # (3, 2)\n        result = merge_lora_matrices(lora_A, lora_B)\n        # eye(3,2) @ (ones(2,3) * 0.5) = first 2 rows are 0.5, last row is 0\n        assert result.shape == (3, 3)\n        assert torch.allclose(result[:2], torch.full((2, 3), 0.5))\n        assert torch.allclose(result[2], torch.zeros(3))\n\n\n# ---------------------------------------------------------------------------\n# expand_expert_lora_tensors\n# ---------------------------------------------------------------------------\n\n\nclass TestExpandExpertLoraTensors:\n    def test_expand_A_to_match_B(self):\n        lora_A = torch.ones(1, 2, 4)\n        lora_B = torch.ones(3, 8, 2)\n        out_A, out_B = expand_expert_lora_tensors(lora_A, lora_B)\n        assert out_A.shape[0] == 3\n        assert out_B is lora_B\n\n    def test_expand_B_to_match_A(self):\n        lora_A = torch.ones(3, 2, 4)\n        lora_B = torch.ones(1, 8, 2)\n        out_A, out_B = expand_expert_lora_tensors(lora_A, lora_B)\n        assert out_A is lora_A\n        assert out_B.shape[0] == 3\n\n    def test_both_single_raises(self):\n        with pytest.raises(WeightsMergeError, match=\"both A and B have 1 expert\"):\n            expand_expert_lora_tensors(torch.ones(1, 2, 4), torch.ones(1, 8, 2))\n\n    def test_already_matched_is_noop(self):\n        lora_A = torch.ones(4, 2, 4)\n        lora_B = torch.ones(4, 8, 2)\n        out_A, out_B = expand_expert_lora_tensors(lora_A, lora_B)\n        assert out_A is lora_A\n        assert out_B is lora_B\n\n    def test_mismatched_expert_counts_raises(self):\n        with pytest.raises(WeightsMergeError, match=\"Expert count mismatch\"):\n            expand_expert_lora_tensors(torch.ones(3, 2, 4), torch.ones(5, 8, 2))\n\n\n# ---------------------------------------------------------------------------\n# detect_merge_profile\n# ---------------------------------------------------------------------------\n\n\nclass TestDetectMergeProfile:\n    def test_standard_model(self):\n        config: dict = {\"architectures\": [\"QwenForCausalLM\"]}\n        keys = {\"model.layers.0.self_attn.q_proj.weight\"}\n        profile = detect_merge_profile(config, keys)\n        assert profile.expert_layout == \"separate\"\n        assert 
profile.extra_key_remaps == ()\n        assert profile.has_language_model_prefix is False\n\n    def test_gpt_oss_detection(self):\n        config: dict = {\"architectures\": [\"GptOssForCausalLM\"]}\n        keys = {\n            \"model.layers.0.self_attn.q_proj.weight\",\n            \"model.layers.0.mlp.experts.gate_up_proj\",\n        }\n        profile = detect_merge_profile(config, keys)\n        assert profile.expert_layout == \"fused_interleaved\"\n        assert (\".attn\", \".self_attn\") in profile.extra_key_remaps\n\n    def test_fused_concatenated_non_gpt_oss(self):\n        config: dict = {\"architectures\": [\"Qwen3ForCausalLM\"]}\n        keys = {\"model.layers.0.mlp.experts.gate_up_proj\"}\n        profile = detect_merge_profile(config, keys)\n        assert profile.expert_layout == \"fused_concatenated\"\n        assert profile.extra_key_remaps == ()\n\n    def test_vision_model_prefix(self):\n        config: dict = {\"architectures\": [\"Qwen3VLForConditionalGeneration\"]}\n        keys = {\"model.language_model.layers.0.self_attn.q_proj.weight\"}\n        profile = detect_merge_profile(config, keys)\n        assert profile.has_language_model_prefix is True\n\n    def test_separate_experts_without_fused(self):\n        config: dict = {\"architectures\": [\"QwenMoEForCausalLM\"]}\n        keys = {\"model.layers.0.mlp.experts.0.gate_proj.weight\"}\n        profile = detect_merge_profile(config, keys)\n        assert profile.expert_layout == \"separate\"\n\n    def test_empty_architectures(self):\n        profile = detect_merge_profile({}, {\"model.layers.0.weight\"})\n        assert profile.expert_layout == \"separate\"\n        assert profile.extra_key_remaps == ()\n\n\n# ---------------------------------------------------------------------------\n# plan_merge_ops\n# ---------------------------------------------------------------------------\n\n\nclass TestPlanMergeOps:\n    def test_non_expert_produces_2d_op(self):\n        keys = {\"model.layers.0.self_attn.q_proj.weight\"}\n        profile = MergeProfile()\n        adapter = {\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n        ops = plan_merge_ops(adapter, {\"lora_alpha\": 2, \"r\": 1}, keys, profile)\n        assert \"model.layers.0.self_attn.q_proj.weight\" in ops\n        op = ops[\"model.layers.0.self_attn.q_proj.weight\"][0]\n        assert op.lora_A.ndim == 2\n        assert op.is_expert_3d is False\n        # B should be pre-scaled by alpha/r = 2\n        assert torch.allclose(op.lora_B, torch.ones(8, 1) * 2)\n\n    def test_separate_experts_produce_per_expert_2d_ops(self):\n        keys = {\n            \"model.layers.0.mlp.experts.0.gate_proj.weight\",\n            \"model.layers.0.mlp.experts.1.gate_proj.weight\",\n        }\n        profile = MergeProfile(expert_layout=\"separate\")\n        adapter = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": torch.ones(2, 1, 4),\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": torch.ones(2, 8, 1),\n        }\n        ops = plan_merge_ops(adapter, {\"lora_alpha\": 1, \"r\": 1}, keys, profile)\n        # Should have ops for both experts\n        assert \"model.layers.0.mlp.experts.0.gate_proj.weight\" in ops\n        assert \"model.layers.0.mlp.experts.1.gate_proj.weight\" in ops\n        # Each op should have 2D tensors\n        for key in 
ops:\n            assert ops[key][0].lora_A.ndim == 2\n\n    def test_fused_experts_produce_3d_ops(self):\n        keys = {\"model.layers.0.mlp.experts.gate_up_proj\"}\n        profile = MergeProfile(expert_layout=\"fused_concatenated\")\n        adapter = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": torch.ones(2, 1, 4),\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": torch.ones(2, 8, 1),\n        }\n        ops = plan_merge_ops(adapter, {\"lora_alpha\": 1, \"r\": 1}, keys, profile)\n        assert \"model.layers.0.mlp.experts.gate_up_proj\" in ops\n        op = ops[\"model.layers.0.mlp.experts.gate_up_proj\"][0]\n        assert op.is_expert_3d is True\n        assert op.fused_proj_idx == 0  # gate = w1\n        assert op.fused_proj_interleaved is False\n\n    def test_fused_interleaved_sets_flag(self):\n        keys = {\"model.layers.0.mlp.experts.gate_up_proj\"}\n        profile = MergeProfile(expert_layout=\"fused_interleaved\")\n        adapter = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": torch.ones(2, 1, 4),\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": torch.ones(2, 8, 1),\n        }\n        ops = plan_merge_ops(adapter, {\"lora_alpha\": 1, \"r\": 1}, keys, profile)\n        op = ops[\"model.layers.0.mlp.experts.gate_up_proj\"][0]\n        assert op.fused_proj_interleaved is True\n\n    def test_extra_key_remaps_applied(self):\n        keys = {\"model.layers.0.self_attn.q_proj.weight\"}\n        profile = MergeProfile(extra_key_remaps=((\".attn\", \".self_attn\"),))\n        adapter = {\n            \"base_model.model.model.layers.0.attn.q_proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n        ops = plan_merge_ops(adapter, {\"lora_alpha\": 1, \"r\": 1}, keys, profile)\n        assert \"model.layers.0.self_attn.q_proj.weight\" in ops\n\n    def test_vision_prefix_remapping(self):\n        keys = {\"model.language_model.layers.0.self_attn.q_proj.weight\"}\n        profile = MergeProfile(has_language_model_prefix=True)\n        adapter = {\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n        ops = plan_merge_ops(adapter, {\"lora_alpha\": 1, \"r\": 1}, keys, profile)\n        assert \"model.language_model.layers.0.self_attn.q_proj.weight\" in ops\n\n    def test_unembed_tokens_remapped_to_lm_head(self):\n        keys = {\"lm_head.weight\"}\n        profile = MergeProfile()\n        adapter = {\n            \"base_model.model.model.unembed_tokens.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.unembed_tokens.lora_B.weight\": torch.ones(8, 1),\n        }\n        ops = plan_merge_ops(adapter, {\"lora_alpha\": 1, \"r\": 1}, keys, profile)\n        assert \"lm_head.weight\" in ops\n\n    def test_missing_key_raises(self):\n        keys = {\"some.other.weight\"}\n        profile = MergeProfile()\n        adapter = {\n            \"base_model.model.model.layers.0.q_proj.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.model.layers.0.q_proj.lora_B.weight\": torch.ones(4, 1),\n        }\n        with pytest.raises(WeightsMergeError, match=\"does not exist\"):\n            plan_merge_ops(adapter, {\"lora_alpha\": 1, \"r\": 1}, keys, profile)\n\n    def 
test_missing_config_key_raises(self):\n        with pytest.raises(WeightsMergeError, match=\"lora_alpha\"):\n            plan_merge_ops({}, {\"r\": 1}, set(), MergeProfile())\n\n    def test_invalid_expert_layout_raises(self):\n        with pytest.raises(WeightsMergeError, match=\"Invalid expert_layout\"):\n            plan_merge_ops({}, {\"lora_alpha\": 1, \"r\": 1}, set(), MergeProfile(expert_layout=\"bad\"))\n\n\n# ---------------------------------------------------------------------------\n# apply_merge_op\n# ---------------------------------------------------------------------------\n\n\nclass TestApplyMergeOp:\n    def test_2d_standard_merge(self):\n        tensors = {\"q_proj.weight\": torch.zeros(8, 4)}\n        op = MergeOp(\n            target_key=\"q_proj.weight\",\n            lora_A=torch.ones(1, 4),\n            lora_B=torch.ones(8, 1),\n        )\n        apply_merge_op(tensors, op)\n        assert tensors[\"q_proj.weight\"].abs().sum() > 0\n        # delta = B @ A = ones(8,1) @ ones(1,4) = ones(8,4)\n        assert torch.allclose(tensors[\"q_proj.weight\"], torch.ones(8, 4))\n\n    def test_3d_fused_concatenated_gate(self):\n        n_exp, in_dim, fused_dim = 2, 4, 8\n        tensors = {\"gate_up_proj\": torch.zeros(n_exp, in_dim, fused_dim)}\n        # Gate op (fused_proj_idx=0): delta goes into first half\n        lora_A = torch.ones(n_exp, 1, in_dim) * 0.1\n        lora_B = torch.ones(n_exp, fused_dim // 2, 1)\n        op = MergeOp(\n            target_key=\"gate_up_proj\",\n            lora_A=lora_A,\n            lora_B=lora_B,\n            is_expert_3d=True,\n            fused_proj_idx=0,\n            fused_proj_interleaved=False,\n        )\n        apply_merge_op(tensors, op)\n        gate_half = tensors[\"gate_up_proj\"][:, :, : fused_dim // 2]\n        up_half = tensors[\"gate_up_proj\"][:, :, fused_dim // 2 :]\n        assert gate_half.abs().sum() > 0\n        assert up_half.abs().sum() == 0\n\n    def test_3d_fused_interleaved_up(self):\n        n_exp, in_dim, fused_dim = 2, 4, 8\n        tensors = {\"gate_up_proj\": torch.zeros(n_exp, in_dim, fused_dim)}\n        # Up op (fused_proj_idx=1, interleaved): delta goes into odd columns\n        lora_A = torch.ones(n_exp, 1, in_dim) * 0.2\n        lora_B = torch.ones(n_exp, fused_dim // 2, 1)\n        op = MergeOp(\n            target_key=\"gate_up_proj\",\n            lora_A=lora_A,\n            lora_B=lora_B,\n            is_expert_3d=True,\n            fused_proj_idx=1,\n            fused_proj_interleaved=True,\n        )\n        apply_merge_op(tensors, op)\n        gate_slots = tensors[\"gate_up_proj\"][:, :, 0::2]\n        up_slots = tensors[\"gate_up_proj\"][:, :, 1::2]\n        assert gate_slots.abs().sum() == 0\n        assert up_slots.abs().sum() > 0\n\n    def test_shape_mismatch_raises(self):\n        tensors = {\"weight\": torch.zeros(4, 4)}\n        op = MergeOp(\n            target_key=\"weight\",\n            lora_A=torch.ones(1, 8),  # wrong in_dim\n            lora_B=torch.ones(4, 1),\n        )\n        with pytest.raises(WeightsMergeError, match=\"Shape mismatch\"):\n            apply_merge_op(tensors, op)\n\n\n# ---------------------------------------------------------------------------\n# validate_merge_op_shapes\n# ---------------------------------------------------------------------------\n\n\nclass TestValidateMergeOpShapes:\n    def test_valid_2d_op_passes(self):\n        ops = {\n            \"q_proj.weight\": [\n                MergeOp(\n                    target_key=\"q_proj.weight\", 
lora_A=torch.ones(1, 4), lora_B=torch.ones(8, 1)\n                )\n            ]\n        }\n        shapes = {\"q_proj.weight\": (8, 4)}\n        validate_merge_op_shapes(ops, shapes)  # should not raise\n\n    def test_invalid_2d_shape_raises(self):\n        ops = {\n            \"q_proj.weight\": [\n                MergeOp(\n                    target_key=\"q_proj.weight\", lora_A=torch.ones(1, 8), lora_B=torch.ones(4, 1)\n                )\n            ]\n        }\n        shapes = {\"q_proj.weight\": (8, 4)}  # target is (8,4) but delta is (4,8)\n        with pytest.raises(WeightsMergeError, match=r\"Shape mismatch.*q_proj\"):\n            validate_merge_op_shapes(ops, shapes)\n\n    def test_valid_3d_fused_concatenated_passes(self):\n        ops = {\n            \"gate_up_proj\": [\n                MergeOp(\n                    target_key=\"gate_up_proj\",\n                    lora_A=torch.ones(2, 1, 4),\n                    lora_B=torch.ones(2, 4, 1),\n                    is_expert_3d=True,\n                    fused_proj_idx=0,\n                    fused_proj_interleaved=False,\n                )\n            ]\n        }\n        # Target is (2, 4, 8) — fused gate+up, each half is (2, 4, 4)\n        # Delta via bmm is (2, 4, 4) which matches the half\n        shapes = {\"gate_up_proj\": (2, 4, 8)}\n        validate_merge_op_shapes(ops, shapes)  # should not raise\n\n    def test_invalid_3d_fused_shape_raises(self):\n        ops = {\n            \"gate_up_proj\": [\n                MergeOp(\n                    target_key=\"gate_up_proj\",\n                    lora_A=torch.ones(2, 1, 4),\n                    lora_B=torch.ones(2, 6, 1),  # wrong out_dim\n                    is_expert_3d=True,\n                    fused_proj_idx=0,\n                    fused_proj_interleaved=False,\n                )\n            ]\n        }\n        shapes = {\"gate_up_proj\": (2, 4, 8)}  # half is (2, 4, 4), delta is (2, 4, 6)\n        with pytest.raises(WeightsMergeError, match=r\"Shape mismatch.*gate_up_proj\"):\n            validate_merge_op_shapes(ops, shapes)\n\n    def test_valid_3d_non_fused_passes(self):\n        ops = {\n            \"down_proj\": [\n                MergeOp(\n                    target_key=\"down_proj\",\n                    lora_A=torch.ones(2, 1, 4),\n                    lora_B=torch.ones(2, 8, 1),\n                    is_expert_3d=True,\n                    fused_proj_idx=None,\n                )\n            ]\n        }\n        shapes = {\"down_proj\": (2, 4, 8)}\n        validate_merge_op_shapes(ops, shapes)  # should not raise\n"
  },
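  {
    "path": "tinker_cookbook/weights/examples/merge_sketch.py",
    "content": "\"\"\"Illustrative sketch of the low-level merge helpers (not part of the original test suite).\n\nAdded for documentation purposes: it exercises merge_lora_matrices and the\nplan_merge_ops / apply_merge_op pair on synthetic tensors, mirroring the\nshapes used in merge_test.py. All values and this file's location are made up.\n\"\"\"\n\nimport torch\n\nfrom tinker_cookbook.weights._merge import (\n    MergeProfile,\n    apply_merge_op,\n    merge_lora_matrices,\n    plan_merge_ops,\n)\n\nif __name__ == \"__main__\":\n    # Rank-1 LoRA pair: A is (rank, in_dim), B is (out_dim, rank).\n    lora_A = torch.ones(1, 4)\n    lora_B = torch.ones(8, 1)\n    delta = merge_lora_matrices(lora_A, lora_B)  # B @ A, shape (out_dim, in_dim)\n    assert delta.shape == (8, 4)\n\n    # plan_merge_ops maps adapter keys onto base-model state-dict keys and\n    # pre-scales lora_B by lora_alpha / r; apply_merge_op adds the delta in place.\n    state = {\"model.layers.0.self_attn.q_proj.weight\": torch.zeros(8, 4)}\n    adapter = {\n        \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": lora_A,\n        \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": lora_B,\n    }\n    ops = plan_merge_ops(adapter, {\"lora_alpha\": 2, \"r\": 1}, set(state), MergeProfile())\n    for key_ops in ops.values():\n        for op in key_ops:\n            apply_merge_op(state, op)\n    # With lora_alpha / r = 2 the merged weight is uniformly 2.0.\n    print(state[\"model.layers.0.self_attn.q_proj.weight\"][0, 0])\n"
  },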
  {
    "path": "tinker_cookbook/weights/publish_test.py",
    "content": "\"\"\"Tests for publish_to_hf_hub.\"\"\"\n\nimport tempfile\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\n\nfrom tinker_cookbook.weights import publish_to_hf_hub\n\n\nclass TestPublishToHfHub:\n    def test_raises_on_nonexistent_path(self):\n        with pytest.raises(FileNotFoundError, match=\"does not exist\"):\n            publish_to_hf_hub(\n                model_path=\"/nonexistent/path\",\n                repo_id=\"user/model\",\n            )\n\n    def test_calls_hf_api_correctly(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            mock_api = MagicMock()\n\n            with patch(\"tinker_cookbook.weights._publish.HfApi\", return_value=mock_api):\n                url = publish_to_hf_hub(\n                    model_path=tmpdir,\n                    repo_id=\"user/my-model\",\n                    private=True,\n                )\n\n            mock_api.create_repo.assert_called_once_with(\n                repo_id=\"user/my-model\",\n                repo_type=\"model\",\n                private=True,\n                exist_ok=True,\n            )\n            mock_api.upload_folder.assert_called_once_with(\n                folder_path=tmpdir,\n                repo_id=\"user/my-model\",\n                repo_type=\"model\",\n            )\n            assert url == \"https://huggingface.co/user/my-model\"\n\n    def test_public_repo(self):\n        with tempfile.TemporaryDirectory() as tmpdir:\n            mock_api = MagicMock()\n\n            with patch(\"tinker_cookbook.weights._publish.HfApi\", return_value=mock_api):\n                publish_to_hf_hub(\n                    model_path=tmpdir,\n                    repo_id=\"org/public-model\",\n                    private=False,\n                )\n\n            mock_api.create_repo.assert_called_once_with(\n                repo_id=\"org/public-model\",\n                repo_type=\"model\",\n                private=False,\n                exist_ok=True,\n            )\n"
  },
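  {
    "path": "tinker_cookbook/weights/examples/publish_sketch.py",
    "content": "\"\"\"Illustrative usage sketch for publish_to_hf_hub (not part of the original test suite).\n\nAdded for documentation purposes: the call mirrors the signature exercised in\npublish_test.py. The model path, repo id, and this file's location are\nplaceholders; a real run needs a Hugging Face token with write access.\n\"\"\"\n\nfrom tinker_cookbook.weights import publish_to_hf_hub\n\nif __name__ == \"__main__\":\n    # Creates the repo if needed (exist_ok) and uploads the folder, returning\n    # the repo URL, e.g. \"https://huggingface.co/<user>/<model>\"\n    # (per the assertions in publish_test.py).\n    url = publish_to_hf_hub(\n        model_path=\"/path/to/merged-model\",  # placeholder: local export directory\n        repo_id=\"your-username/your-model\",  # placeholder\n        private=True,\n    )\n    print(url)\n"
  },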
  {
    "path": "tinker_cookbook/weights/quantized_test.py",
    "content": "\"\"\"Unit tests for quantized export strategy.\n\nCovers FP8 math, DeepSeek detection, weight classification, vLLM config\ngeneration, resume state, and output shard assembly. Uses synthetic data —\nno network or GPU required.\n\"\"\"\n\nimport json\nfrom pathlib import Path\n\nimport pytest\nimport torch\nfrom safetensors.torch import load_file, save_file\n\nfrom tinker_cookbook.exceptions import WeightsMergeError\nfrom tinker_cookbook.weights._export._quantized import (\n    _build_vllm_quantization_config,\n    _is_routed_expert_weight,\n    _load_resume_state,\n    _save_merge_state,\n    _save_shard_atomic,\n    _serialize_for_vllm,\n    _should_skip_checkpoint_key,\n    dequantize_blockwise,\n    is_deepseek_config,\n    quantize_blockwise,\n)\n\n# ---------------------------------------------------------------------------\n# FP8 quantize / dequantize round-trip\n# ---------------------------------------------------------------------------\n\n\nclass TestFP8RoundTrip:\n    def test_exact_round_trip_block_size_1(self):\n        \"\"\"Block size (1,1) should give exact round-trip for representable values.\"\"\"\n        tensor = torch.tensor([[1.0, -0.5], [0.25, 0.0]])\n        fp8, scale = quantize_blockwise(tensor, block_size=(1, 1))\n        assert fp8.dtype == torch.float8_e4m3fn\n        assert scale.dtype == torch.float32\n        recovered = dequantize_blockwise(fp8, scale, block_size=(1, 1))\n        assert torch.allclose(recovered.float(), tensor, atol=1e-2)\n\n    def test_padded_dimensions(self):\n        \"\"\"Tensor dimensions not divisible by block size should still round-trip.\"\"\"\n        tensor = torch.randn(5, 7)\n        fp8, scale = quantize_blockwise(tensor, block_size=(2, 3))\n        assert fp8.shape == (5, 7)\n        # Scale shape should be ceil(5/2) x ceil(7/3) = 3 x 3\n        assert scale.shape == (3, 3)\n        recovered = dequantize_blockwise(fp8, scale, block_size=(2, 3), dtype=torch.float32)\n        assert recovered.shape == (5, 7)\n        # Round-trip error should be small\n        assert torch.allclose(recovered, tensor, atol=0.2)\n\n    def test_large_block_preserves_shape(self):\n        \"\"\"Standard 128x128 block size with a realistic shape.\"\"\"\n        tensor = torch.randn(256, 384)\n        fp8, scale = quantize_blockwise(tensor)\n        assert fp8.shape == (256, 384)\n        assert scale.shape == (2, 3)  # ceil(256/128) x ceil(384/128)\n\n    def test_zeros_round_trip(self):\n        \"\"\"All-zero tensor should round-trip cleanly.\"\"\"\n        tensor = torch.zeros(4, 4)\n        fp8, scale = quantize_blockwise(tensor, block_size=(2, 2))\n        recovered = dequantize_blockwise(fp8, scale, block_size=(2, 2), dtype=torch.float32)\n        assert torch.allclose(recovered, tensor)\n\n\n# ---------------------------------------------------------------------------\n# DeepSeek detection\n# ---------------------------------------------------------------------------\n\n\nclass TestIsDeepseekConfig:\n    def test_deepseek_v3_detected(self):\n        assert is_deepseek_config({\"model_type\": \"deepseek_v3\"})\n\n    def test_non_deepseek_rejected(self):\n        assert not is_deepseek_config({\"model_type\": \"qwen2_moe\"})\n        assert not is_deepseek_config({\"model_type\": \"llama\"})\n        assert not is_deepseek_config({})\n\n    def test_similar_strings_rejected(self):\n        assert not is_deepseek_config({\"model_type\": \"deepseek\"})\n        assert not is_deepseek_config({\"model_type\": 
\"deepseek_v2\"})\n\n\n# ---------------------------------------------------------------------------\n# Weight classification\n# ---------------------------------------------------------------------------\n\n\nclass TestIsRoutedExpertWeight:\n    def test_routed_expert_matched(self):\n        assert _is_routed_expert_weight(\"model.layers.3.mlp.experts.42.gate_proj.weight\")\n        assert _is_routed_expert_weight(\"model.layers.0.mlp.experts.0.down_proj.weight\")\n\n    def test_shared_expert_rejected(self):\n        assert not _is_routed_expert_weight(\"model.layers.0.mlp.shared_experts.gate_proj.weight\")\n\n    def test_attention_rejected(self):\n        assert not _is_routed_expert_weight(\"model.layers.0.self_attn.q_proj.weight\")\n\n    def test_norm_rejected(self):\n        assert not _is_routed_expert_weight(\"model.layers.0.input_layernorm.weight\")\n\n    def test_embed_rejected(self):\n        assert not _is_routed_expert_weight(\"model.embed_tokens.weight\")\n\n\n# ---------------------------------------------------------------------------\n# Skip key logic\n# ---------------------------------------------------------------------------\n\n\nclass TestShouldSkipCheckpointKey:\n    def test_rotary_emb_skipped(self):\n        assert _should_skip_checkpoint_key(\"model.layers.0.self_attn.rotary_emb.inv_freq\")\n\n    def test_layer_61_skipped(self):\n        assert _should_skip_checkpoint_key(\"model.layers.61.self_attn.q_proj.weight\")\n        assert _should_skip_checkpoint_key(\"model.layers.61.mlp.experts.0.gate_proj.weight\")\n\n    def test_normal_layer_not_skipped(self):\n        assert not _should_skip_checkpoint_key(\"model.layers.0.self_attn.q_proj.weight\")\n        assert not _should_skip_checkpoint_key(\"model.layers.60.mlp.gate_proj.weight\")\n\n    def test_embed_not_skipped(self):\n        assert not _should_skip_checkpoint_key(\"model.embed_tokens.weight\")\n\n\n# ---------------------------------------------------------------------------\n# vLLM quantization config\n# ---------------------------------------------------------------------------\n\n\nclass TestBuildVllmQuantizationConfig:\n    def test_correct_schema(self):\n        weight_map = {\n            \"model.layers.0.mlp.experts.0.gate_proj.weight\": \"shard-1.safetensors\",\n            \"model.layers.0.mlp.experts.0.gate_proj.weight_scale\": \"shard-1.safetensors\",\n            \"model.layers.0.self_attn.q_proj.weight\": \"shard-1.safetensors\",\n            \"model.embed_tokens.weight\": \"shard-1.safetensors\",\n        }\n        config = _build_vllm_quantization_config(weight_map)\n        assert config[\"quant_method\"] == \"compressed-tensors\"\n        assert config[\"format\"] == \"float-quantized\"\n        assert config[\"quantization_status\"] == \"compressed\"\n        assert \"config_groups\" in config\n        assert \"ignore\" in config\n        # Verify block quantization config\n        weights = config[\"config_groups\"][\"group_0\"][\"weights\"]\n        assert weights[\"strategy\"] == \"block\"\n        assert weights[\"block_structure\"] == [128, 128]\n        # Verify input activations\n        ia = config[\"config_groups\"][\"group_0\"][\"input_activations\"]\n        assert ia[\"dynamic\"] is True\n\n    def test_ignore_list_correct(self):\n        \"\"\"Dense projections should be in ignore, routed experts should NOT.\"\"\"\n        weight_map = {\n            \"model.layers.0.mlp.experts.0.gate_proj.weight\": \"s.safetensors\",\n            
\"model.layers.0.mlp.experts.0.gate_proj.weight_scale\": \"s.safetensors\",\n            \"model.layers.0.self_attn.q_proj.weight\": \"s.safetensors\",\n            \"model.layers.0.mlp.shared_experts.gate_proj.weight\": \"s.safetensors\",\n        }\n        config = _build_vllm_quantization_config(weight_map)\n        ignore = config[\"ignore\"]\n        # Dense/shared should be in ignore\n        assert \"model.layers.0.self_attn.q_proj\" in ignore\n        assert \"model.layers.0.mlp.shared_experts.gate_proj\" in ignore\n        # Routed expert should NOT be in ignore\n        assert \"model.layers.0.mlp.experts.0.gate_proj\" not in ignore\n\n\nclass TestSerializeForVllm:\n    def test_strips_unknown_fields(self):\n        config = {\n            \"quant_method\": \"compressed-tensors\",\n            \"format\": \"float-quantized\",\n            \"unknown_field\": \"should be stripped\",\n            \"another_unknown\": 42,\n            \"ignore\": [],\n            \"config_groups\": {},\n        }\n        result = _serialize_for_vllm(config)\n        assert \"unknown_field\" not in result\n        assert \"another_unknown\" not in result\n        assert result[\"quant_method\"] == \"compressed-tensors\"\n        assert result[\"ignore\"] == []\n\n    def test_preserves_known_fields(self):\n        config = {\n            \"quant_method\": \"compressed-tensors\",\n            \"format\": \"float-quantized\",\n            \"quantization_status\": \"compressed\",\n            \"global_compression_ratio\": None,\n            \"config_groups\": {\n                \"group_0\": {\n                    \"targets\": [\"Linear\"],\n                    \"weights\": {\"num_bits\": 8, \"strategy\": \"block\"},\n                }\n            },\n            \"ignore\": [\"a.b\"],\n        }\n        result = _serialize_for_vllm(config)\n        assert result[\"quant_method\"] == \"compressed-tensors\"\n        assert result[\"ignore\"] == [\"a.b\"]\n        assert result[\"config_groups\"][\"group_0\"][\"targets\"] == [\"Linear\"]\n\n\n# ---------------------------------------------------------------------------\n# Resume state\n# ---------------------------------------------------------------------------\n\n\nclass TestResumeState:\n    def test_no_state_file_returns_empty(self, tmp_path: Path):\n        assert _load_resume_state(tmp_path) == {}\n\n    def test_load_valid_state(self, tmp_path: Path):\n        save_file({\"x\": torch.zeros(1)}, str(tmp_path / \"shard-1.safetensors\"))\n        _save_merge_state(\n            tmp_path,\n            status=\"in_progress\",\n            completed_shards=[\"shard-1.safetensors\"],\n            total_shards=2,\n        )\n        state = _load_resume_state(tmp_path)\n        assert state[\"status\"] == \"in_progress\"\n        assert state[\"completed_shards\"] == [\"shard-1.safetensors\"]\n        assert state[\"total_shards\"] == 2\n\n    def test_missing_shard_file_raises(self, tmp_path: Path):\n        \"\"\"Resume state references a shard that doesn't exist on disk.\"\"\"\n        _save_merge_state(\n            tmp_path,\n            status=\"in_progress\",\n            completed_shards=[\"missing.safetensors\"],\n            total_shards=1,\n        )\n        with pytest.raises(WeightsMergeError, match=\"not found\"):\n            _load_resume_state(tmp_path)\n\n    def test_atomic_save(self, tmp_path: Path):\n        \"\"\"Merge state should be saved atomically (no partial writes).\"\"\"\n        _save_merge_state(tmp_path, 
status=\"in_progress\", completed_shards=[], total_shards=3)\n        state_file = tmp_path / \"merge_state.json\"\n        assert state_file.exists()\n        # Temp file should not exist\n        assert not (tmp_path / \"merge_state.json.tmp\").exists()\n\n\nclass TestSaveShardAtomic:\n    def test_atomic_write(self, tmp_path: Path):\n        tensors = {\"a\": torch.ones(2, 3)}\n        _save_shard_atomic(tmp_path, \"shard-1.safetensors\", tensors)\n        assert (tmp_path / \"shard-1.safetensors\").exists()\n        assert not (tmp_path / \"shard-1.safetensors.tmp\").exists()\n        loaded = load_file(str(tmp_path / \"shard-1.safetensors\"))\n        assert torch.equal(loaded[\"a\"], torch.ones(2, 3))\n\n\n# ---------------------------------------------------------------------------\n# Output shard assembly (quantize behavior)\n# ---------------------------------------------------------------------------\n\n\nclass TestOutputShardAssembly:\n    \"\"\"Test that routed experts get FP8+scale and dense stays BF16.\"\"\"\n\n    def _make_model_and_adapter(self, tmp_path: Path):\n        \"\"\"Create a minimal DeepSeek-like model with one shard.\"\"\"\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        num_experts = 2\n\n        # Model state: routed experts + one dense weight + shared expert\n        state_dict = {}\n        for i in range(num_experts):\n            state_dict[f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\"] = torch.randn(\n                8, 4, dtype=torch.bfloat16\n            )\n        state_dict[\"model.layers.0.self_attn.q_proj.weight\"] = torch.randn(\n            8, 4, dtype=torch.bfloat16\n        )\n        state_dict[\"model.layers.0.mlp.shared_experts.gate_proj.weight\"] = torch.randn(\n            8, 4, dtype=torch.bfloat16\n        )\n        model_dir.mkdir(parents=True)\n        save_file(state_dict, str(model_dir / \"model.safetensors\"))\n        config = {\"model_type\": \"deepseek_v3\", \"architectures\": [\"DeepseekV3ForCausalLM\"]}\n        (model_dir / \"config.json\").write_text(json.dumps(config))\n        (model_dir / \"tokenizer_config.json\").write_text(\n            json.dumps({\"tokenizer_class\": \"PreTrainedTokenizerFast\"})\n        )\n        (model_dir / \"tokenizer.json\").write_text(\n            json.dumps(\n                {\n                    \"version\": \"1.0\",\n                    \"model\": {\"type\": \"BPE\", \"vocab\": {\"a\": 0, \"b\": 1}, \"merges\": []},\n                    \"added_tokens\": [],\n                }\n            )\n        )\n\n        # Adapter targeting the experts and dense weight\n        adapter_weights = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": torch.ones(\n                num_experts, 1, 4\n            )\n            * 0.01,\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": torch.ones(\n                num_experts, 8, 1\n            ),\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": torch.ones(1, 4)\n            * 0.01,\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.ones(8, 1),\n        }\n        adapter_dir.mkdir(parents=True)\n        save_file(adapter_weights, str(adapter_dir / \"adapter_model.safetensors\"))\n        (adapter_dir / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": 1}))\n\n        return model_dir, adapter_dir\n\n    def test_routed_expert_quantized_to_fp8(self, 
tmp_path: Path):\n        from tinker_cookbook.weights._export._quantized import build_quantized\n\n        model_dir, adapter_dir = self._make_model_and_adapter(tmp_path)\n        output_dir = tmp_path / \"output\"\n\n        build_quantized(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            trust_remote_code=False,\n            model_dir=model_dir,\n            config_dict=json.loads((model_dir / \"config.json\").read_text()),\n            serving_format=\"vllm\",\n        )\n\n        out_tensors = load_file(str(output_dir / \"model.safetensors\"))\n\n        # Routed expert should be FP8\n        expert_w = out_tensors[\"model.layers.0.mlp.experts.0.gate_proj.weight\"]\n        assert expert_w.dtype == torch.float8_e4m3fn\n\n        # Should have a scale tensor\n        assert \"model.layers.0.mlp.experts.0.gate_proj.weight_scale\" in out_tensors\n        scale = out_tensors[\"model.layers.0.mlp.experts.0.gate_proj.weight_scale\"]\n        assert scale.dtype == torch.float32\n\n    def test_dense_stays_bf16(self, tmp_path: Path):\n        from tinker_cookbook.weights._export._quantized import build_quantized\n\n        model_dir, adapter_dir = self._make_model_and_adapter(tmp_path)\n        output_dir = tmp_path / \"output\"\n\n        build_quantized(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            trust_remote_code=False,\n            model_dir=model_dir,\n            config_dict=json.loads((model_dir / \"config.json\").read_text()),\n            serving_format=\"vllm\",\n        )\n\n        out_tensors = load_file(str(output_dir / \"model.safetensors\"))\n\n        # Dense weight should stay BF16\n        q_proj = out_tensors[\"model.layers.0.self_attn.q_proj.weight\"]\n        assert q_proj.dtype == torch.bfloat16\n\n        # No scale tensor for dense weights\n        assert \"model.layers.0.self_attn.q_proj.weight_scale\" not in out_tensors\n\n    def test_shared_expert_stays_bf16(self, tmp_path: Path):\n        from tinker_cookbook.weights._export._quantized import build_quantized\n\n        model_dir, adapter_dir = self._make_model_and_adapter(tmp_path)\n        output_dir = tmp_path / \"output\"\n\n        build_quantized(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            trust_remote_code=False,\n            model_dir=model_dir,\n            config_dict=json.loads((model_dir / \"config.json\").read_text()),\n            serving_format=\"vllm\",\n        )\n\n        out_tensors = load_file(str(output_dir / \"model.safetensors\"))\n\n        shared = out_tensors[\"model.layers.0.mlp.shared_experts.gate_proj.weight\"]\n        assert shared.dtype == torch.bfloat16\n        assert \"model.layers.0.mlp.shared_experts.gate_proj.weight_scale\" not in out_tensors\n\n    def test_config_has_compression_config(self, tmp_path: Path):\n        from tinker_cookbook.weights._export._quantized import build_quantized\n\n        model_dir, adapter_dir = self._make_model_and_adapter(tmp_path)\n        output_dir = tmp_path / \"output\"\n\n        build_quantized(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            trust_remote_code=False,\n            model_dir=model_dir,\n            config_dict=json.loads((model_dir / 
\"config.json\").read_text()),\n            serving_format=\"vllm\",\n        )\n\n        config = json.loads((output_dir / \"config.json\").read_text())\n        assert \"compression_config\" in config\n        cc = config[\"compression_config\"]\n        assert cc[\"quant_method\"] == \"compressed-tensors\"\n        assert \"quantization_config\" not in config\n\n\n# ---------------------------------------------------------------------------\n# Cross-shard native FP8 scale handling\n# ---------------------------------------------------------------------------\n\n\nclass TestCrossShardFP8Scale:\n    \"\"\"Test that native FP8 weights are dequantized correctly even when\n    the weight and its scale_inv are in different shards.\"\"\"\n\n    def test_cross_shard_scale_dequantizes_correctly(self, tmp_path: Path):\n        from tinker_cookbook.weights._export._quantized import build_quantized, quantize_blockwise\n\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n        model_dir.mkdir(parents=True)\n\n        # Create a native FP8 expert weight and its scale, in DIFFERENT shards\n        original_weight = torch.randn(8, 4, dtype=torch.bfloat16)\n        fp8_weight, scale_inv = quantize_blockwise(original_weight, block_size=(4, 4))\n\n        # Shard 1: has the FP8 weight but NOT its scale\n        shard1 = {\n            \"model.layers.0.mlp.experts.0.gate_proj.weight\": fp8_weight,\n            \"model.layers.0.mlp.experts.1.gate_proj.weight\": fp8_weight.clone(),\n        }\n        # Shard 2: has the scale_inv but NOT the weight, plus another expert\n        shard2 = {\n            \"model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv\": scale_inv,\n            \"model.layers.0.mlp.experts.1.gate_proj.weight_scale_inv\": scale_inv.clone(),\n            \"model.layers.0.self_attn.q_proj.weight\": torch.randn(8, 4, dtype=torch.bfloat16),\n        }\n\n        save_file(shard1, str(model_dir / \"model-00001-of-00002.safetensors\"))\n        save_file(shard2, str(model_dir / \"model-00002-of-00002.safetensors\"))\n\n        weight_map = {}\n        for k in shard1:\n            weight_map[k] = \"model-00001-of-00002.safetensors\"\n        for k in shard2:\n            weight_map[k] = \"model-00002-of-00002.safetensors\"\n\n        total_size = sum(t.nelement() * t.element_size() for t in {**shard1, **shard2}.values())\n        index = {\"metadata\": {\"total_size\": total_size}, \"weight_map\": weight_map}\n        (model_dir / \"model.safetensors.index.json\").write_text(json.dumps(index))\n\n        # Config with native FP8 quantization\n        config = {\n            \"model_type\": \"deepseek_v3\",\n            \"architectures\": [\"DeepseekV3ForCausalLM\"],\n            \"quantization_config\": {\n                \"quant_method\": \"fp8\",\n                \"weight_block_size\": [4, 4],\n            },\n        }\n        (model_dir / \"config.json\").write_text(json.dumps(config))\n        (model_dir / \"tokenizer_config.json\").write_text(\n            json.dumps({\"tokenizer_class\": \"PreTrainedTokenizerFast\"})\n        )\n        (model_dir / \"tokenizer.json\").write_text(\n            json.dumps(\n                {\n                    \"version\": \"1.0\",\n                    \"model\": {\"type\": \"BPE\", \"vocab\": {\"a\": 0, \"b\": 1}, \"merges\": []},\n                    \"added_tokens\": [],\n                }\n            )\n        )\n\n        # Empty adapter (no merge, just 
quantize)\n        adapter_dir.mkdir(parents=True)\n        save_file({\"dummy\": torch.zeros(1)}, str(adapter_dir / \"adapter_model.safetensors\"))\n        (adapter_dir / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": 1}))\n\n        build_quantized(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            trust_remote_code=False,\n            model_dir=model_dir,\n            config_dict=config,\n            serving_format=\"vllm\",\n        )\n\n        # Verify: expert weights should be FP8 (re-quantized after dequant)\n        out = {}\n        for p in sorted(output_dir.glob(\"*.safetensors\")):\n            out.update(load_file(str(p)))\n\n        expert_key = \"model.layers.0.mlp.experts.0.gate_proj.weight\"\n        assert expert_key in out\n        assert out[expert_key].dtype == torch.float8_e4m3fn\n        # Scale should use compressed-tensors naming\n        assert \"model.layers.0.mlp.experts.0.gate_proj.weight_scale\" in out\n        # No native scale_inv in output\n        assert \"model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv\" not in out\n\n    def test_merge_applied_before_requantize_on_native_fp8(self, tmp_path: Path):\n        \"\"\"Regression: LoRA merge must happen AFTER dequant, BEFORE requant.\n\n        If merge is applied to the raw FP8 tensor (before dequant), the delta\n        gets corrupted because FP8 can't represent the fine-grained LoRA values.\n        \"\"\"\n        from tinker_cookbook.weights._export._quantized import (\n            build_quantized,\n            dequantize_blockwise,\n            quantize_blockwise,\n        )\n\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n        model_dir.mkdir(parents=True)\n\n        # Create a native FP8 expert weight\n        original_bf16 = torch.randn(8, 4, dtype=torch.bfloat16)\n        fp8_weight, scale_inv = quantize_blockwise(original_bf16, block_size=(4, 4))\n\n        num_experts = 2\n        shard1: dict[str, torch.Tensor] = {}\n        for i in range(num_experts):\n            shard1[f\"model.layers.0.mlp.experts.{i}.gate_proj.weight\"] = fp8_weight.clone()\n            shard1[f\"model.layers.0.mlp.experts.{i}.gate_proj.weight_scale_inv\"] = scale_inv.clone()\n\n        save_file(shard1, str(model_dir / \"model.safetensors\"))\n        config = {\n            \"model_type\": \"deepseek_v3\",\n            \"architectures\": [\"DeepseekV3ForCausalLM\"],\n            \"quantization_config\": {\"quant_method\": \"fp8\", \"weight_block_size\": [4, 4]},\n        }\n        (model_dir / \"config.json\").write_text(json.dumps(config))\n        (model_dir / \"tokenizer_config.json\").write_text(\n            json.dumps({\"tokenizer_class\": \"PreTrainedTokenizerFast\"})\n        )\n        (model_dir / \"tokenizer.json\").write_text(\n            json.dumps(\n                {\n                    \"version\": \"1.0\",\n                    \"model\": {\"type\": \"BPE\", \"vocab\": {\"a\": 0, \"b\": 1}, \"merges\": []},\n                    \"added_tokens\": [],\n                }\n            )\n        )\n\n        # LoRA adapter targeting gate_proj (w1) with a known delta\n        lora_fill = 0.1\n        adapter_weights = {\n            \"base_model.model.model.layers.0.mlp.experts.w1.lora_A.weight\": (\n                torch.ones(num_experts, 1, 4) * lora_fill\n            ),\n            
\"base_model.model.model.layers.0.mlp.experts.w1.lora_B.weight\": torch.ones(\n                num_experts, 8, 1\n            ),\n        }\n        adapter_dir.mkdir(parents=True)\n        save_file(adapter_weights, str(adapter_dir / \"adapter_model.safetensors\"))\n        (adapter_dir / \"adapter_config.json\").write_text(json.dumps({\"lora_alpha\": 1, \"r\": 1}))\n\n        build_quantized(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            trust_remote_code=False,\n            model_dir=model_dir,\n            config_dict=config,\n            serving_format=\"vllm\",\n        )\n\n        # Load output and dequantize to check the merge was applied\n        out = {}\n        for p in sorted(output_dir.glob(\"*.safetensors\")):\n            out.update(load_file(str(p)))\n\n        expert_key = \"model.layers.0.mlp.experts.0.gate_proj.weight\"\n        scale_key = \"model.layers.0.mlp.experts.0.gate_proj.weight_scale\"\n        merged_dequantized = dequantize_blockwise(\n            out[expert_key], out[scale_key], block_size=(128, 128)\n        )\n\n        # The original dequantized value\n        original_dequantized = dequantize_blockwise(fp8_weight, scale_inv, block_size=(4, 4))\n\n        # The merged result should differ from original by approximately the LoRA delta\n        delta = (merged_dequantized.float() - original_dequantized.float()).abs()\n        assert delta.sum() > 0, \"LoRA merge had no effect — merge may have been applied to FP8\"\n        # The delta should be approximately lora_fill (0.1) everywhere\n        # Allow tolerance for FP8 round-trip quantization error\n        assert delta.mean().item() == pytest.approx(lora_fill, abs=0.05), (\n            f\"Expected delta ~{lora_fill}, got {delta.mean().item():.4f}. \"\n            \"Merge may have been applied before dequantization.\"\n        )\n"
  },
  {
    "path": "tinker_cookbook/weights/stress_test.py",
    "content": "\"\"\"Stress tests for the weights module.\n\nExercises edge cases, numerical correctness, and cross-path consistency\nto catch subtle bugs. Uses synthetic models — no network or GPU required.\n\"\"\"\n\nimport json\nimport logging\nfrom pathlib import Path\n\nimport pytest\nimport torch\nfrom safetensors.torch import load_file, save_file\n\nfrom tinker_cookbook.weights._artifacts import (\n    ShardWriter,\n    get_model_state_shapes,\n)\nfrom tinker_cookbook.weights._export import build_hf_model, load_config_dict\nfrom tinker_cookbook.weights._merge import (\n    MergeOp,\n    MergeProfile,\n    apply_merge_op,\n    detect_merge_profile,\n    merge_adapter_weights,\n    merge_lora_matrices,\n    plan_merge_ops,\n    validate_merge_op_shapes,\n)\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _create_model(path: Path, config: dict, state_dict: dict) -> None:\n    path.mkdir(parents=True)\n    (path / \"config.json\").write_text(json.dumps(config))\n    save_file(state_dict, str(path / \"model.safetensors\"))\n    (path / \"tokenizer_config.json\").write_text(\n        json.dumps({\"tokenizer_class\": \"PreTrainedTokenizerFast\"})\n    )\n    (path / \"tokenizer.json\").write_text(\n        json.dumps(\n            {\n                \"version\": \"1.0\",\n                \"model\": {\"type\": \"BPE\", \"vocab\": {\"a\": 0, \"b\": 1}, \"merges\": []},\n                \"added_tokens\": [],\n            }\n        )\n    )\n\n\ndef _create_adapter(path: Path, weights: dict, config: dict) -> None:\n    path.mkdir(parents=True)\n    save_file(weights, str(path / \"adapter_model.safetensors\"))\n    (path / \"adapter_config.json\").write_text(json.dumps(config))\n\n\n# ---------------------------------------------------------------------------\n# Numerical correctness: verify exact LoRA delta values\n# ---------------------------------------------------------------------------\n\n\nclass TestNumericalCorrectness:\n    \"\"\"Verify that the LoRA math produces exactly the right values.\"\"\"\n\n    def test_merge_lora_matrices_matches_manual(self):\n        \"\"\"B @ A should match manual computation.\"\"\"\n        lora_A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # (2, 3)\n        lora_B = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # (3, 2)\n        result = merge_lora_matrices(lora_A, lora_B)\n        # B @ A = [[1,2,3], [4,5,6], [5,7,9]]\n        expected = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [5.0, 7.0, 9.0]])\n        assert torch.allclose(result, expected)\n\n    def test_scaling_applied_correctly(self):\n        \"\"\"lora_alpha/r scaling should multiply lora_B before merge.\"\"\"\n        keys = {\"w.weight\"}\n        profile = MergeProfile()\n        adapter = {\n            \"base_model.model.w.lora_A.weight\": torch.ones(1, 4),\n            \"base_model.model.w.lora_B.weight\": torch.ones(8, 1),\n        }\n        # alpha=4, r=2 → scaling=2, so B becomes 2*ones\n        ops = plan_merge_ops(adapter, {\"lora_alpha\": 4, \"r\": 2}, keys, profile)\n        op = ops[\"w.weight\"][0]\n        # delta = scaled_B @ A = 2*ones(8,1) @ ones(1,4) = 2*ones(8,4)\n        tensors = {\"w.weight\": torch.zeros(8, 4)}\n        apply_merge_op(tensors, op)\n        assert torch.allclose(tensors[\"w.weight\"], torch.full((8, 4), 2.0))\n\n    def test_higher_rank_lora(self):\n        \"\"\"Rank > 1 LoRA should produce correct 
delta.\"\"\"\n        rank = 4\n        lora_A = torch.eye(rank, 8)  # (4, 8)\n        lora_B = torch.ones(8, rank)  # (8, 4)\n        delta = merge_lora_matrices(lora_A, lora_B)\n        # ones(8,4) @ eye(4,8) = first 4 cols are ones, rest are zeros\n        assert delta.shape == (8, 8)\n        assert torch.allclose(delta[:, :4], torch.ones(8, 4))\n        assert torch.allclose(delta[:, 4:], torch.zeros(8, 4))\n\n    def test_bfloat16_precision(self):\n        \"\"\"Merge should upcast to float32 then cast back, preserving precision.\"\"\"\n        tensors = {\"w\": torch.zeros(4, 4, dtype=torch.bfloat16)}\n        op = MergeOp(\n            target_key=\"w\",\n            lora_A=torch.ones(1, 4, dtype=torch.float32) * 0.001,\n            lora_B=torch.ones(4, 1, dtype=torch.float32),\n        )\n        apply_merge_op(tensors, op)\n        # Result should be close to 0.001 (bfloat16 can represent this approximately)\n        assert tensors[\"w\"].dtype == torch.bfloat16\n        assert tensors[\"w\"].float().mean().item() == pytest.approx(0.001, abs=1e-4)\n\n\n# ---------------------------------------------------------------------------\n# Cross-path consistency: full vs shard should produce identical output\n# (using synthetic models that work with both paths)\n# ---------------------------------------------------------------------------\n\n\nclass TestCrossPathConsistency:\n    \"\"\"Verify plan_merge_ops + apply_merge_op matches merge_adapter_weights.\"\"\"\n\n    def _make_model_and_adapter(self):\n        \"\"\"Create synthetic model state dict and adapter weights.\"\"\"\n        state_dict = {\n            \"model.layers.0.self_attn.q_proj.weight\": torch.randn(8, 4),\n            \"model.layers.0.self_attn.k_proj.weight\": torch.randn(8, 4),\n            \"model.layers.0.mlp.gate_proj.weight\": torch.randn(16, 4),\n        }\n        adapter = {\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight\": torch.randn(2, 4),\n            \"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight\": torch.randn(8, 2),\n            \"base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight\": torch.randn(2, 4),\n            \"base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight\": torch.randn(16, 2),\n        }\n        config = {\"lora_alpha\": 8, \"r\": 2}\n        return state_dict, adapter, config\n\n    def test_plan_apply_matches_merge_adapter_weights(self):\n        \"\"\"Plan+apply path should produce bit-identical results to merge_adapter_weights.\"\"\"\n        state_dict, adapter, config = self._make_model_and_adapter()\n\n        # Path 1: merge_adapter_weights (backward-compat wrapper)\n        sd1 = {k: v.clone() for k, v in state_dict.items()}\n        model1 = type(\"Model\", (torch.nn.Module,), {\"state_dict\": lambda self: sd1})()\n        merge_adapter_weights(model1, adapter, config)\n\n        # Path 2: plan + apply (new API)\n        sd2 = {k: v.clone() for k, v in state_dict.items()}\n        profile = detect_merge_profile({\"architectures\": []}, set(sd2.keys()))\n        ops = plan_merge_ops(adapter, config, set(sd2.keys()), profile)\n        for op_list in ops.values():\n            for op in op_list:\n                apply_merge_op(sd2, op)\n\n        # Must be bit-identical\n        for key in state_dict:\n            assert torch.equal(sd1[key], sd2[key]), f\"Mismatch on {key}\"\n\n    def test_multiple_adapters_targeting_same_key(self):\n        \"\"\"Multiple LoRA ops on the same key should all be applied.\"\"\"\n      
  tensors = {\"w.weight\": torch.zeros(4, 4)}\n        # Two ops targeting the same key\n        ops = {\n            \"w.weight\": [\n                MergeOp(target_key=\"w.weight\", lora_A=torch.ones(1, 4), lora_B=torch.ones(4, 1)),\n                MergeOp(\n                    target_key=\"w.weight\",\n                    lora_A=torch.ones(1, 4) * 2,\n                    lora_B=torch.ones(4, 1),\n                ),\n            ]\n        }\n        for op in ops[\"w.weight\"]:\n            apply_merge_op(tensors, op)\n        # First adds 1.0, second adds 2.0 → total 3.0\n        assert torch.allclose(tensors[\"w.weight\"], torch.full((4, 4), 3.0))\n\n\n# ---------------------------------------------------------------------------\n# Expert merge edge cases\n# ---------------------------------------------------------------------------\n\n\nclass TestExpertEdgeCases:\n    def test_fused_gate_and_up_both_applied(self):\n        \"\"\"Both gate and up ops should be applied to the same fused tensor.\"\"\"\n        n_exp, in_dim, fused_dim = 2, 4, 8\n        tensors = {\"fused\": torch.zeros(n_exp, in_dim, fused_dim)}\n\n        # Gate op (idx=0, concatenated) → first half\n        gate_op = MergeOp(\n            target_key=\"fused\",\n            lora_A=torch.ones(n_exp, 1, in_dim) * 0.1,\n            lora_B=torch.ones(n_exp, fused_dim // 2, 1),\n            is_expert_3d=True,\n            fused_proj_idx=0,\n        )\n        # Up op (idx=1, concatenated) → second half\n        up_op = MergeOp(\n            target_key=\"fused\",\n            lora_A=torch.ones(n_exp, 1, in_dim) * 0.2,\n            lora_B=torch.ones(n_exp, fused_dim // 2, 1),\n            is_expert_3d=True,\n            fused_proj_idx=1,\n        )\n        apply_merge_op(tensors, gate_op)\n        apply_merge_op(tensors, up_op)\n\n        gate_half = tensors[\"fused\"][:, :, : fused_dim // 2]\n        up_half = tensors[\"fused\"][:, :, fused_dim // 2 :]\n        assert torch.allclose(gate_half, torch.full_like(gate_half, 0.1), atol=1e-6)\n        assert torch.allclose(up_half, torch.full_like(up_half, 0.2), atol=1e-6)\n\n    def test_fused_down_proj_no_slicing(self):\n        \"\"\"Down proj (fused_proj_idx=None) should apply to full tensor.\"\"\"\n        n_exp, in_dim, out_dim = 2, 4, 8\n        tensors = {\"down\": torch.zeros(n_exp, in_dim, out_dim)}\n        op = MergeOp(\n            target_key=\"down\",\n            lora_A=torch.ones(n_exp, 1, in_dim) * 0.3,\n            lora_B=torch.ones(n_exp, out_dim, 1),\n            is_expert_3d=True,\n            fused_proj_idx=None,\n        )\n        apply_merge_op(tensors, op)\n        assert torch.allclose(tensors[\"down\"], torch.full_like(tensors[\"down\"], 0.3), atol=1e-6)\n\n\n# ---------------------------------------------------------------------------\n# Shard path edge cases\n# ---------------------------------------------------------------------------\n\n\nclass TestShardPathEdgeCases:\n    def test_adapter_targets_subset_of_model_keys(self, tmp_path: Path):\n        \"\"\"Adapter only touches some model keys; untouched keys should be preserved.\"\"\"\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n\n        original_value = torch.randn(8, 4)\n        _create_model(\n            model_dir,\n            {\"architectures\": [\"Test\"]},\n            {\n                \"model.layers.0.q_proj.weight\": torch.zeros(8, 4),\n                \"model.layers.0.k_proj.weight\": 
original_value.clone(),\n            },\n        )\n        _create_adapter(\n            adapter_dir,\n            {\n                \"base_model.model.model.layers.0.q_proj.lora_A.weight\": torch.ones(1, 4),\n                \"base_model.model.model.layers.0.q_proj.lora_B.weight\": torch.ones(8, 1),\n            },\n            {\"lora_alpha\": 1, \"r\": 1},\n        )\n\n        build_hf_model(\n            base_model=str(model_dir),\n            adapter_path=str(adapter_dir),\n            output_path=str(output_dir),\n            merge_strategy=\"shard\",\n        )\n\n        out = load_file(str(output_dir / \"model.safetensors\"))\n        # q_proj was merged\n        assert out[\"model.layers.0.q_proj.weight\"].abs().sum() > 0\n        # k_proj should be exactly preserved (bit-for-bit)\n        assert torch.equal(out[\"model.layers.0.k_proj.weight\"], original_value)\n\n    def test_empty_adapter_produces_identical_output(self, tmp_path: Path, caplog):\n        \"\"\"Adapter with no LoRA weights should produce model identical to input.\"\"\"\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n\n        original = torch.randn(4, 4)\n        _create_model(\n            model_dir,\n            {\"architectures\": [\"Test\"]},\n            {\"model.weight\": original.clone()},\n        )\n        # Adapter with no LoRA keys (just some random tensor)\n        _create_adapter(adapter_dir, {\"random_key\": torch.zeros(1)}, {\"lora_alpha\": 1, \"r\": 1})\n\n        with caplog.at_level(logging.WARNING):\n            build_hf_model(\n                base_model=str(model_dir),\n                adapter_path=str(adapter_dir),\n                output_path=str(output_dir),\n                merge_strategy=\"shard\",\n            )\n\n        assert \"No LoRA weights found\" in caplog.text\n        out = load_file(str(output_dir / \"model.safetensors\"))\n        assert torch.equal(out[\"model.weight\"], original)\n\n    def test_output_exists_raises_before_any_work(self, tmp_path: Path):\n        \"\"\"Should raise FileExistsError early, before expensive merge work.\"\"\"\n        model_dir = tmp_path / \"model\"\n        adapter_dir = tmp_path / \"adapter\"\n        output_dir = tmp_path / \"output\"\n\n        _create_model(\n            model_dir,\n            {\"architectures\": [\"Test\"]},\n            {\"model.w.weight\": torch.zeros(4, 4)},\n        )\n        _create_adapter(\n            adapter_dir,\n            {\n                \"base_model.model.model.w.lora_A.weight\": torch.ones(1, 4),\n                \"base_model.model.model.w.lora_B.weight\": torch.ones(4, 1),\n            },\n            {\"lora_alpha\": 1, \"r\": 1},\n        )\n        output_dir.mkdir()  # pre-create to trigger conflict\n\n        with pytest.raises(FileExistsError, match=\"already exists\"):\n            build_hf_model(\n                base_model=str(model_dir),\n                adapter_path=str(adapter_dir),\n                output_path=str(output_dir),\n                merge_strategy=\"shard\",\n            )\n\n\n# ---------------------------------------------------------------------------\n# Shape validation edge cases\n# ---------------------------------------------------------------------------\n\n\nclass TestShapeValidationEdgeCases:\n    def test_shapes_read_without_loading_tensors(self, tmp_path: Path):\n        \"\"\"get_model_state_shapes should be fast and not load actual tensor data.\"\"\"\n        # Create a 
model with known shapes\n        tensors = {\n            \"a\": torch.zeros(100, 200),\n            \"b\": torch.zeros(3, 4, 5),\n            \"c\": torch.zeros(7),\n        }\n        save_file(tensors, str(tmp_path / \"model.safetensors\"))\n\n        shapes = get_model_state_shapes(tmp_path)\n        assert shapes == {\"a\": (100, 200), \"b\": (3, 4, 5), \"c\": (7,)}\n\n    def test_validate_catches_wrong_rank(self):\n        \"\"\"Shape validation should catch rank-1 LoRA targeting rank-2 weight.\"\"\"\n        ops = {\"w\": [MergeOp(target_key=\"w\", lora_A=torch.ones(1, 4), lora_B=torch.ones(8, 1))]}\n        shapes = {\"w\": (4, 8)}  # transposed from what the LoRA produces\n        with pytest.raises(ValueError, match=\"Shape mismatch\"):\n            validate_merge_op_shapes(ops, shapes)\n\n\n# ---------------------------------------------------------------------------\n# Profile detection edge cases\n# ---------------------------------------------------------------------------\n\n\nclass TestProfileDetectionEdgeCases:\n    def test_gpt_oss_without_fused_experts(self):\n        \"\"\"GPT-OSS model without experts should get separate layout.\"\"\"\n        config: dict = {\"architectures\": [\"GptOssForCausalLM\"]}\n        keys = {\"model.layers.0.self_attn.q_proj.weight\"}\n        profile = detect_merge_profile(config, keys)\n        assert profile.expert_layout == \"separate\"\n        assert (\".attn\", \".self_attn\") in profile.extra_key_remaps\n\n    def test_vision_gpt_oss_combination(self):\n        \"\"\"GPT-OSS vision model should get both attn remap and language_model prefix.\"\"\"\n        config: dict = {\"architectures\": [\"GptOssVisionModel\"]}\n        keys = {\n            \"model.language_model.layers.0.self_attn.q_proj.weight\",\n            \"model.language_model.layers.0.mlp.experts.gate_up_proj\",\n        }\n        profile = detect_merge_profile(config, keys)\n        assert profile.expert_layout == \"fused_interleaved\"\n        assert profile.has_language_model_prefix is True\n        assert (\".attn\", \".self_attn\") in profile.extra_key_remaps\n\n\n# ---------------------------------------------------------------------------\n# Config loading edge cases\n# ---------------------------------------------------------------------------\n\n\nclass TestConfigLoadingEdgeCases:\n    def test_local_dir_without_config_raises(self, tmp_path: Path):\n        \"\"\"Local directory without config.json should raise FileNotFoundError.\"\"\"\n        with pytest.raises(FileNotFoundError, match=r\"config\\.json\"):\n            load_config_dict(tmp_path)\n\n    def test_local_dir_with_config_loads(self, tmp_path: Path):\n        (tmp_path / \"config.json\").write_text('{\"architectures\": [\"TestModel\"]}')\n        config = load_config_dict(tmp_path)\n        assert config[\"architectures\"] == [\"TestModel\"]\n\n\n# ---------------------------------------------------------------------------\n# ShardWriter edge cases\n# ---------------------------------------------------------------------------\n\n\nclass TestShardWriterEdgeCases:\n    def test_single_tensor_larger_than_max_shard(self, tmp_path: Path):\n        \"\"\"A tensor larger than max_shard_size should get its own shard.\"\"\"\n        writer = ShardWriter(tmp_path, max_shard_size=100)\n        # float32 tensor of 1000 elements = 4000 bytes >> 100 byte limit\n        writer.add_tensor(\"big\", torch.zeros(1000))\n        writer.add_tensor(\"small\", torch.zeros(1))\n        weight_map = writer.finalize()\n\n   
     assert len(set(weight_map.values())) == 2\n        assert weight_map[\"big\"] != weight_map[\"small\"]\n\n    def test_shard_count_correct_after_multiple_flushes(self, tmp_path: Path):\n        writer = ShardWriter(tmp_path, max_shard_size=100)\n        for i in range(5):\n            writer.add_tensor(f\"t{i}\", torch.zeros(100))  # each triggers flush\n        weight_map = writer.finalize()\n        assert len(set(weight_map.values())) == 5\n"
  },
  {
    "path": "tinker_cookbook/xmux/README.md",
    "content": "# xmux - TMUX-based Experiment Launcher\n\nxmux is a tool for launching and managing hierarchical ML experiments using TMUX. It provides an interactive control window for monitoring and managing large numbers of concurrent experiments.\n\n## Key Features\n\n- **Hierarchical Organization**: Session = Sweep, with a control window for management\n- **Smart Grouping**: Group related experiments in the same window as panes\n- **Interactive Control**: Navigate, monitor, and kill experiments from the control window\n- **Smart Naming**: Automatic abbreviation of long experiment names\n- **Multi-line Status Bar**: Clear overview of all running experiments\n\n## Quick Start\n\n```python\nfrom tinker_cookbook.xmux import JobSpec, SwarmConfig, launch_swarm\n\n# Define your experiments\njob_specs = [\n    JobSpec(\n        main_fn=train_model,  # Your training function\n        log_relpath=\"sweep/model1/lr0.001\",\n        entrypoint_config={\"model\": \"bert\", \"lr\": 0.001}\n    ),\n    # ... more experiments\n]\n\n# Launch the swarm\nconfig = SwarmConfig(sweep_name=\"my-lr-sweep\")\nlaunch_swarm(job_specs, config)\n```\n\n## Grouping Experiments\n\nYou can group related experiments into the same window:\n\n```python\n# Group by model type\nJobSpec(\n    main_fn=train_model,\n    log_relpath=\"sweep/bert/lr0.001\",\n    entrypoint_config=config,\n    tmux_window_name=\"bert\",  # Groups all BERT experiments\n    pane_title=\"lr0.001\"      # Shows in the pane\n)\n```\n\n## Using the Control Window\n\nAfter launching, attach to the TMUX session:\n\n```bash\ntmux attach-session -t my-lr-sweep\n```\n\nControl window commands:\n- **0-9**: Jump to window by number\n- **↑↓**: Navigate job list\n- **k**: Kill selected job\n- **K**: Kill entire window group\n- **r**: Refresh status\n- **q**: Quit control window\n\n## Adding to an Existing Experiment\n\nIf you already have an existing session, you can add\nadditional jobs to the experiment by using the same\nsweep name.\n\n## Examples\n\nSee `examples/ml_sweep.py` for complete examples:\n\n```bash\n# Run demo with dry-run to see what would happen\npython examples/ml_sweep.py 1 --dry-run\n\n# Run actual experiments\npython examples/ml_sweep.py 2\n\n# Demo options:\n# 1 - Individual windows (no grouping)\n# 2 - Grouped by model\n# 3 - Mixed grouping strategy\n# 4 - Large scale sweep (72 experiments)\n```\n\n## Tips\n\n1. **Kill entire sweep**: `tmux kill-session -t sweep-name`\n2. **List xmux sessions**: Look for sessions with metadata in `~/experiments/.xmux/`\n3. **Window limit**: Use grouping for large sweeps to avoid too many windows\n4. **Pane limit**: Set `max_panes_per_window` to control pane overflow\n"
  },
  {
    "path": "tinker_cookbook/xmux/__init__.py",
    "content": "\"\"\"xmux - TMUX-based experiment launcher for ML sweeps\"\"\"\n\nfrom .core import JobSpec, SwarmConfig, launch_swarm\n\n__version__ = \"0.1.0\"\n__all__ = [\"JobSpec\", \"SwarmConfig\", \"launch_swarm\"]\n"
  },
  {
    "path": "tinker_cookbook/xmux/control.py",
    "content": "#!/usr/bin/env python\n\"\"\"Control window for xmux - provides interactive interface for managing experiments\"\"\"\n\nimport contextlib\nimport curses\nimport os\nimport subprocess\nimport sys\nimport time\nfrom datetime import datetime\nfrom enum import StrEnum\n\nfrom pydantic import BaseModel\n\n\nclass JobStatus(StrEnum):\n    \"\"\"Job status enumeration\"\"\"\n\n    UNKNOWN = \"unknown\"\n    RUNNING = \"running\"\n    COMPLETED = \"completed\"\n    FAILED = \"failed\"\n\n\nclass PaneJobInfo(BaseModel):\n    \"\"\"Information about a pane job from metadata\"\"\"\n\n    log_relpath: str\n    display_name: str\n\n\nclass WindowJobInfo(BaseModel):\n    \"\"\"Information about a window job from metadata\"\"\"\n\n    window_name: str\n    panes: dict[str, PaneJobInfo]\n\n\nclass SessionMetadata(BaseModel):\n    \"\"\"Session metadata structure\"\"\"\n\n    session_name: str\n    sweep_name: str | None = None\n    total_jobs: int = 0\n    window_groups: dict[str, int] | None = None\n    ungrouped_jobs: int = 0\n    pane_titles: dict[str, list[str]] | None = None\n    job_mapping: dict[str, WindowJobInfo] | None = None\n\n\nclass PaneInfo(BaseModel):\n    \"\"\"Information about a tmux pane\"\"\"\n\n    index: int\n    pid: int | None\n    dead: bool\n\n\nclass JobInfo(BaseModel):\n    \"\"\"Information about a job\"\"\"\n\n    # Required fields (no defaults) must come first\n    window_index: int\n    window_name: str\n    log_relpath: str\n\n    # Optional fields (with defaults) come after\n    pane_index: int | None = None\n    status: JobStatus = JobStatus.UNKNOWN\n    pid: int | None = None\n\n\ndef load_existing_metadata(session_name: str) -> SessionMetadata | None:\n    \"\"\"Load existing session metadata\"\"\"\n    metadata_path = os.path.expanduser(f\"~/experiments/.xmux/{session_name}.json\")\n    if os.path.exists(metadata_path):\n        with open(metadata_path) as f:\n            return SessionMetadata.model_validate_json(f.read())\n    return None\n\n\nclass ControlWindow:\n    \"\"\"Interactive control window for managing xmux sessions\"\"\"\n\n    def __init__(self, session_name: str):\n        self.session_name: str = session_name\n        self.jobs: list[JobInfo] = []\n        self.selected_index: int = 0\n        self.last_refresh: float = time.time()\n        self.start_time: datetime = datetime.now()\n\n        metadata = load_existing_metadata(self.session_name)\n        assert metadata is not None\n        self.metadata: SessionMetadata = metadata\n\n        # Set up debug log file\n        self.debug_log: str = os.path.expanduser(\n            f\"~/experiments/.xmux/{session_name}_control_debug.log\"\n        )\n        os.makedirs(os.path.dirname(self.debug_log), exist_ok=True)\n\n    def debug_print(self, msg: str) -> None:\n        \"\"\"Write debug messages to log file\"\"\"\n        with open(self.debug_log, \"a\") as f:\n            _ = f.write(f\"[{datetime.now().strftime('%H:%M:%S')}] {msg}\\n\")\n\n    def _load_metadata(self) -> SessionMetadata:\n        \"\"\"Load session metadata\"\"\"\n        metadata_path = os.path.expanduser(f\"~/experiments/.xmux/{self.session_name}.json\")\n        if os.path.exists(metadata_path):\n            with open(metadata_path) as f:\n                return SessionMetadata.model_validate_json(f.read())\n        return SessionMetadata(session_name=self.session_name)\n\n    def _get_window_list(self) -> list[tuple[int, str]]:\n        \"\"\"Get list of windows in the session\"\"\"\n        try:\n            
result = subprocess.run(\n                [\n                    \"tmux\",\n                    \"list-windows\",\n                    \"-t\",\n                    self.session_name,\n                    \"-F\",\n                    \"#{window_index}:#{window_name}\",\n                ],\n                capture_output=True,\n                text=True,\n                check=True,\n            )\n            windows: list[tuple[int, str]] = []\n            for line in result.stdout.strip().split(\"\\n\"):\n                if line and \":\" in line:\n                    idx, name = line.split(\":\", 1)\n                    if name != \"control\":  # Skip control window\n                        windows.append((int(idx), name))\n            return windows\n        except subprocess.CalledProcessError:\n            return []\n\n    def _get_pane_info(self, window_index: int) -> list[PaneInfo]:\n        \"\"\"Get information about panes in a window\"\"\"\n        try:\n            result = subprocess.run(\n                [\n                    \"tmux\",\n                    \"list-panes\",\n                    \"-t\",\n                    f\"{self.session_name}:{window_index}\",\n                    \"-F\",\n                    \"#{pane_index}:#{pane_pid}:#{pane_dead}\",\n                ],\n                capture_output=True,\n                text=True,\n                check=True,\n            )\n            panes: list[PaneInfo] = []\n            for line in result.stdout.strip().split(\"\\n\"):\n                if line and \":\" in line:\n                    parts = line.split(\":\")\n                    panes.append(\n                        PaneInfo(\n                            index=int(parts[0]),\n                            pid=int(parts[1]) if parts[1] else None,\n                            dead=parts[2] == \"1\",\n                        )\n                    )\n            return panes\n        except subprocess.CalledProcessError:\n            return []\n\n    def _check_job_status(self, job: JobInfo) -> JobStatus:\n        \"\"\"Check the status of a job by looking at its log files\"\"\"\n        log_dir = os.path.expanduser(f\"~/experiments/{job.log_relpath}\")\n\n        # First check if process is still running\n        is_dead = False\n        if job.window_index:\n            panes = self._get_pane_info(job.window_index)\n            if job.pane_index is not None:\n                # Specific pane\n                for pane in panes:\n                    if pane.index == job.pane_index:\n                        is_dead = pane.dead\n                        break\n            else:\n                # Window with single pane\n                if panes:\n                    is_dead = panes[0].dead\n\n        # If still running, return immediately\n        if not is_dead:\n            return JobStatus.RUNNING\n\n        self.debug_print(f\"Looking in {log_dir}\")\n        # Process is dead - check for completion markers\n        completed_path = os.path.join(log_dir, \".completed\")\n        failed_path = os.path.join(log_dir, \".failed\")\n\n        if os.path.exists(completed_path):\n            self.debug_print(f\"Found completed marker: {completed_path}\")\n            return JobStatus.COMPLETED\n        elif os.path.exists(failed_path):\n            self.debug_print(f\"Found failed marker: {failed_path}\")\n            return JobStatus.FAILED\n        else:\n            # No markers found - this means the marker creation failed\n            # or we're checking too soon. 
Treat as failed.\n            self.debug_print(\n                f\"No markers found in {log_dir}. Files: {os.listdir(log_dir) if os.path.exists(log_dir) else 'DIR NOT FOUND'}\"\n            )\n            return JobStatus.UNKNOWN\n\n    def refresh_jobs(self) -> None:\n        \"\"\"Refresh the list of jobs and their statuses\"\"\"\n        self.jobs = []\n        windows = self._get_window_list()\n\n        # Get job mapping from metadata\n        self.metadata = self._load_metadata()\n        job_mapping = self.metadata.job_mapping or {}\n\n        for window_index, window_name in windows:\n            panes = self._get_pane_info(window_index)\n\n            # Get job info for this window from metadata\n            window_job_info = job_mapping.get(str(window_index))\n            if not window_job_info:\n                continue\n            pane_info = window_job_info.panes\n\n            if len(panes) <= 1:\n                # Single job in window\n                # Get log_relpath from metadata\n                pane_0 = pane_info.get(\"0\")\n                if not pane_0:\n                    self.debug_print(\n                        f\"WARNING: No pane info found for window {window_index} ({window_name})\"\n                    )\n                    # Skip this job if we don't have a valid path\n                    continue\n                log_relpath = pane_0.log_relpath\n                if not log_relpath:\n                    self.debug_print(\n                        f\"WARNING: No log_relpath found for window {window_index} ({window_name})\"\n                    )\n                    # Skip this job if we don't have a valid path\n                    continue\n\n                job = JobInfo(\n                    window_index=window_index,\n                    window_name=window_name,\n                    log_relpath=log_relpath,\n                    pid=panes[0].pid if panes else None,\n                )\n                job.status = self._check_job_status(job)\n                self.jobs.append(job)\n            else:\n                # Multiple jobs in window (grouped)\n                for i, pane in enumerate(panes):\n                    # Get info for this specific pane\n                    pane_job_info = pane_info.get(str(i))\n                    if not pane_job_info:\n                        continue\n                    log_relpath = pane_job_info.log_relpath\n\n                    if not log_relpath:\n                        self.debug_print(\n                            f\"WARNING: No log_relpath found for window {window_index} pane {i} ({window_name})\"\n                        )\n                        # Skip this pane if we don't have a valid path\n                        continue\n\n                    display_name = pane_job_info.display_name or f\"{window_name}[{i}]\"\n\n                    # Use full name with window prefix for grouped panes\n                    full_display_name = f\"{window_name}/{display_name}\"\n\n                    job = JobInfo(\n                        window_index=window_index,\n                        window_name=full_display_name,\n                        log_relpath=log_relpath,\n                        pane_index=pane.index,\n                        pid=pane.pid,\n                    )\n                    job.status = self._check_job_status(job)\n                    self.jobs.append(job)\n\n        self.last_refresh = time.time()\n\n    def draw_header(self, stdscr: curses.window, height: int, width: int) -> None:\n        
\"\"\"Draw the header with session info\"\"\"\n        # Title\n        _ = height  # Parameter required for interface consistency\n        sweep_name = self.metadata.sweep_name or self.session_name\n        title = f\"XMUX CONTROL: {sweep_name}\"\n        stdscr.addstr(0, (width - len(title)) // 2, title, curses.A_BOLD | curses.color_pair(1))\n\n        # Stats\n        total = len(self.jobs)\n        running = sum(1 for j in self.jobs if j.status == JobStatus.RUNNING)\n        completed = sum(1 for j in self.jobs if j.status == JobStatus.COMPLETED)\n        failed = sum(1 for j in self.jobs if j.status == JobStatus.FAILED)\n\n        uptime = datetime.now() - self.start_time\n        hours, remainder = divmod(int(uptime.total_seconds()), 3600)\n        minutes, _ = divmod(remainder, 60)\n\n        stats = f\"Jobs: {total} | Running: {running} | Completed: {completed} | Failed: {failed} | Uptime: {hours}h{minutes}m\"\n        stdscr.addstr(2, 2, stats)\n\n        # Separator\n        stdscr.addstr(3, 0, \"=\" * width, curses.color_pair(1))\n\n    def draw_jobs(self, stdscr: curses.window, height: int, width: int) -> None:\n        \"\"\"Draw the job list\"\"\"\n        # Column headers\n        headers = f\"{'Win':>4} {'Name':<30} {'Status':<12} {'Command':<20}\"\n        stdscr.addstr(5, 2, headers, curses.A_BOLD)\n        stdscr.addstr(6, 0, \"-\" * width)\n\n        # Job list (with scrolling)\n        list_start = 7\n        list_height = height - list_start - 6  # Leave room for footer\n\n        # Calculate scroll position\n        if self.selected_index >= list_height:\n            scroll_offset = self.selected_index - list_height + 1\n        else:\n            scroll_offset = 0\n\n        for i, job in enumerate(self.jobs[scroll_offset : scroll_offset + list_height]):\n            y = list_start + i\n            idx = scroll_offset + i\n\n            # Highlight selected\n            attr = curses.A_REVERSE if idx == self.selected_index else 0\n\n            # Status color\n            if job.status == JobStatus.RUNNING:\n                status_attr = curses.color_pair(2)  # Green\n            elif job.status == JobStatus.COMPLETED:\n                status_attr = curses.color_pair(4)  # Cyan\n            elif job.status == JobStatus.FAILED:\n                status_attr = curses.color_pair(3)  # Red\n            else:\n                status_attr = 0\n\n            # Format job info\n            win_str = f\"{job.window_index:>4}\"\n            name_str = job.window_name[:30].ljust(30)\n            status_str = str(job.status.value).upper()[:12].ljust(12)\n\n            # Draw line\n            stdscr.addstr(y, 2, win_str, attr)\n            stdscr.addstr(y, 7, name_str, attr)\n            stdscr.addstr(y, 38, status_str, attr | status_attr)\n\n    def draw_footer(self, stdscr: curses.window, height: int, width: int) -> None:\n        \"\"\"Draw the footer with commands\"\"\"\n        y = height - 4\n        stdscr.addstr(y, 0, \"-\" * width)\n\n        commands = [\n            \"[0] Control window\",\n            \"[1-9] Go to job group window\",\n            \"[k] Kill job\",\n            \"[K] Kill group\",\n            \"[r] Refresh\",\n            \"[q] Detach\",\n            \"[↑↓] Navigate\",\n            \"[Enter] Select window\",\n        ]\n\n        y += 1\n        cmd_str = \" | \".join(commands)\n        stdscr.addstr(y, 2, cmd_str[: width - 4])\n\n        # Status line\n        y += 2\n        status = f\"Last refresh: {int(time.time() - self.last_refresh)}s ago\"\n 
       stdscr.addstr(y, 2, status)\n\n    def handle_input(self, _stdscr: curses.window, key: int) -> bool:\n        \"\"\"Handle keyboard input\"\"\"\n        if key == ord(\"r\"):\n            self.refresh_jobs()\n\n        elif key == curses.KEY_UP:\n            self.selected_index = max(0, self.selected_index - 1)\n\n        elif key == curses.KEY_DOWN:\n            self.selected_index = min(len(self.jobs) - 1, self.selected_index + 1)\n\n        elif ord(\"0\") <= key <= ord(\"9\"):\n            # Jump to window\n            window_num = key - ord(\"0\")\n            for job in self.jobs:\n                if job.window_index == window_num:\n                    # Switch to window\n                    _ = subprocess.run(\n                        [\"tmux\", \"select-window\", \"-t\", f\"{self.session_name}:{window_num}\"]\n                    )\n                    break\n\n        elif key == ord(\"\\n\") or key == ord(\" \"):\n            # Select job\n            job = self.jobs[self.selected_index]\n            _ = subprocess.run(\n                [\"tmux\", \"select-window\", \"-t\", f\"{self.session_name}:{job.window_index}\"]\n            )\n\n        elif key == ord(\"k\") and self.jobs:\n            # Kill selected job\n            job = self.jobs[self.selected_index]\n            if job.pane_index is not None:\n                # Kill specific pane\n                _ = subprocess.run(\n                    [\n                        \"tmux\",\n                        \"kill-pane\",\n                        \"-t\",\n                        f\"{self.session_name}:{job.window_index}.{job.pane_index}\",\n                    ]\n                )\n            else:\n                # Kill window\n                _ = subprocess.run(\n                    [\"tmux\", \"kill-window\", \"-t\", f\"{self.session_name}:{job.window_index}\"]\n                )\n            self.refresh_jobs()\n\n        elif key == ord(\"K\") and self.jobs:\n            # Kill entire window group\n            job = self.jobs[self.selected_index]\n            _ = subprocess.run(\n                [\"tmux\", \"kill-window\", \"-t\", f\"{self.session_name}:{job.window_index}\"]\n            )\n            self.refresh_jobs()\n\n        elif key == ord(\"q\") or key == 27:  # q or ESC key\n            # Detach from the tmux session but keep control window alive\n            _ = subprocess.run([\"tmux\", \"detach-client\"])\n            # Don't return False - keep the control window running\n\n        return True\n\n    def run(self, stdscr: curses.window) -> None:\n        \"\"\"Main UI loop\"\"\"\n        # Setup colors\n        curses.start_color()\n        curses.init_pair(1, curses.COLOR_CYAN, curses.COLOR_BLACK)  # Header\n        curses.init_pair(2, curses.COLOR_GREEN, curses.COLOR_BLACK)  # Running\n        curses.init_pair(3, curses.COLOR_RED, curses.COLOR_BLACK)  # Failed\n        curses.init_pair(4, curses.COLOR_BLUE, curses.COLOR_BLACK)  # Completed\n\n        # Configure terminal\n        _ = curses.curs_set(0)  # Hide cursor\n        stdscr.nodelay(True)  # Non-blocking input\n        stdscr.timeout(1000)  # Refresh every second\n\n        # Initial refresh\n        self.refresh_jobs()\n\n        while True:\n            try:\n                height, width = stdscr.getmaxyx()\n                stdscr.clear()\n\n                # Draw UI\n                self.draw_header(stdscr, height, width)\n                self.draw_jobs(stdscr, height, width)\n                self.draw_footer(stdscr, height, 
width)\n\n                stdscr.refresh()\n\n                # Handle input\n                key = stdscr.getch()\n                if key != -1 and not self.handle_input(stdscr, key):\n                    break\n\n                # Auto-refresh every 3 seconds\n                if time.time() - self.last_refresh > 3:\n                    self.refresh_jobs()\n\n            except curses.error:\n                # Handle terminal resize or other curses errors\n                time.sleep(0.1)\n                continue\n            except KeyboardInterrupt:\n                # Ignore Ctrl-C gracefully\n                pass\n\n\ndef main() -> None:\n    \"\"\"Entry point for control window\"\"\"\n    if len(sys.argv) < 2:\n        print(\"Usage: control.py <session_name>\")\n        sys.exit(1)\n\n    session_name = sys.argv[1]\n    control = ControlWindow(session_name)\n\n    with contextlib.suppress(KeyboardInterrupt):\n        curses.wrapper(control.run)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tinker_cookbook/xmux/core.py",
    "content": "\"\"\"Core data structures and main launch function for xmux\"\"\"\n\nimport os\nimport shlex\nimport subprocess\nimport sys\nfrom collections.abc import Callable\n\nimport cloudpickle\nfrom pydantic import BaseModel\nfrom termcolor import colored\n\nfrom .control import PaneJobInfo, SessionMetadata, WindowJobInfo, load_existing_metadata\nfrom .utils import generate_unique_names, get_symbol_path\n\n\nclass JobSpec(BaseModel):\n    \"\"\"Specification for a single job in the swarm\"\"\"\n\n    main_fn: Callable[..., object]  # function to run\n    log_relpath: str  # path to log directory\n    entrypoint_config: object  # argument to pass to main_fn\n    tmux_window_name: str | None = None  # If set, groups jobs together\n\n    def get_window_name(self, default_name: str) -> str:\n        \"\"\"Get the window name for this job\"\"\"\n        return self.tmux_window_name or default_name\n\n\nclass SwarmConfig(BaseModel):\n    \"\"\"Configuration for launching a swarm of jobs\"\"\"\n\n    sweep_name: str  # Becomes tmux session name\n    max_panes_per_window: int = 4  # When grouping, split into multiple windows if needed\n    use_pickle: bool = True  # Whether to use pickle for config serialization\n    dry_run: bool = False  # If set, will create the session but not launch any jobs\n    control_window_cmd: str | None = None  # Custom control window command\n    status_format: str | None = None  # Custom status bar format\n    debug: bool = False  # Whether to run jobs with pdb debugger\n    verbose: bool = False  # Whether to enable verbose logging\n\n    def get_session_name(self) -> str:\n        \"\"\"Get sanitized session name\"\"\"\n        # Sanitize session name for tmux\n        return \"\".join(c if c.isalnum() or c in [\"_\", \"-\"] else \"_\" for c in self.sweep_name)[:50]\n\n\nclass JobConfig(BaseModel):\n    \"\"\"Internal job configuration\"\"\"\n\n    log_relpath: str\n    entrypoint: str\n    entrypoint_config: object\n    num_gpus: int = 0\n\n\ndef _execute_command(cmdlist: list[str], verbose: bool = False) -> subprocess.CompletedProcess[str]:\n    \"\"\"Execute a command and return the result\"\"\"\n    cmd_str = shlex.join(cmdlist)\n    if verbose:\n        print(colored(f\"Executing: {cmd_str}\", \"green\"))\n    return subprocess.run(cmdlist, check=True, stdout=sys.stderr, stderr=sys.stderr, text=True)\n\n\ndef _save_job_config(job_spec: JobSpec, use_pickle: bool, dry_run: bool) -> str:\n    \"\"\"Save job configuration and return the config path\"\"\"\n    # Create log directory\n    log_dir = os.path.expanduser(f\"~/experiments/{job_spec.log_relpath}\")\n    os.makedirs(log_dir, exist_ok=True)\n\n    # Get entrypoint path\n    entrypoint = get_symbol_path(job_spec.main_fn)\n\n    # Create JobConfig\n    job_config = JobConfig(\n        log_relpath=job_spec.log_relpath,\n        entrypoint=entrypoint,\n        entrypoint_config=job_spec.entrypoint_config,\n        num_gpus=0,\n    )\n\n    # Save config\n    if use_pickle:\n        config_path = os.path.join(log_dir, \"job_config.pickle\")\n        if not dry_run:\n            with open(config_path, \"wb\") as f:\n                cloudpickle.dump(job_config, f)\n    else:\n        config_path = os.path.join(log_dir, \"job_config.json\")\n        if not dry_run:\n            with open(config_path, \"w\") as f:\n                _ = f.write(job_config.model_dump_json(indent=2))\n\n    return config_path\n\n\nclass JobCommand(BaseModel):\n    \"\"\"Command to run a job\"\"\"\n\n    command: str\n    env: 
dict[str, str]\n\n\ndef _get_tmux_env_flags(env: dict[str, str]) -> list[str]:\n    \"\"\"Get the flags to pass to tmux to set the environment\"\"\"\n    return [flag for key, value in env.items() for flag in (\"-e\", f\"{key}={value}\")]\n\n\ndef _create_job_command(config_path: str, debug: bool = False, dry_run: bool = False) -> JobCommand:\n    \"\"\"Create the command to run a job\"\"\"\n    # Use the same Python interpreter that's running xmux\n    log_dir = os.path.dirname(config_path)\n    if debug:\n        job_cmd_base = f\"{sys.executable} -m pdb -c continue -m tinker_cookbook.xmux.run_job {shlex.quote(config_path)}\"\n    else:\n        job_cmd_base = (\n            f\"{sys.executable} -m tinker_cookbook.xmux.run_job {shlex.quote(config_path)}\"\n        )\n\n    if dry_run:\n        job_cmd_base = 'echo \"dry run enabled, sleeping for 10 seconds\" && sleep 10'\n\n    # Wrap with marker creation - tmux will run this in a shell, so we can use shell syntax directly\n    job_cmd_with_markers = (\n        f\"{job_cmd_base} && touch {shlex.quote(os.path.join(log_dir, '.completed'))} || \"\n        f\"touch {shlex.quote(os.path.join(log_dir, '.failed'))}\"\n    )\n    return JobCommand(command=job_cmd_with_markers, env=os.environ.copy())\n\n\ndef _session_exists(session_name: str) -> bool:\n    \"\"\"Check if a tmux session exists\"\"\"\n    result = subprocess.run(\n        [\"tmux\", \"has-session\", \"-t\", session_name], capture_output=True, check=False\n    )\n    return result.returncode == 0\n\n\ndef _get_next_window_index(metadata: SessionMetadata | None) -> int:\n    \"\"\"Get the next available window index for new jobs\"\"\"\n    if not metadata or not metadata.job_mapping:\n        return 1  # Start from 1 (0 is control window)\n\n    # Find highest existing window index\n    max_index = 0\n    for window_idx in metadata.job_mapping:\n        max_index = max(max_index, int(window_idx))\n\n    return max_index + 1\n\n\ndef _merge_metadata(\n    existing_metadata: SessionMetadata | None, new_metadata: SessionMetadata\n) -> SessionMetadata:\n    \"\"\"Merge new job metadata with existing metadata\"\"\"\n    if not existing_metadata:\n        return new_metadata\n\n    # Merge job_mapping\n    merged_job_mapping: dict[str, WindowJobInfo] = {}\n    if existing_metadata.job_mapping:\n        merged_job_mapping.update(existing_metadata.job_mapping)\n    if new_metadata.job_mapping:\n        merged_job_mapping.update(new_metadata.job_mapping)\n\n    # Merge window_groups\n    merged_window_groups: dict[str, int] = {}\n    if existing_metadata.window_groups:\n        merged_window_groups.update(existing_metadata.window_groups)\n    if new_metadata.window_groups:\n        merged_window_groups.update(new_metadata.window_groups)\n\n    # Merge pane_titles\n    merged_pane_titles: dict[str, list[str]] = {}\n    if existing_metadata.pane_titles:\n        merged_pane_titles.update(existing_metadata.pane_titles)\n    if new_metadata.pane_titles:\n        merged_pane_titles.update(new_metadata.pane_titles)\n\n    # Calculate new totals\n    existing_total = existing_metadata.total_jobs\n    new_total = new_metadata.total_jobs\n\n    existing_ungrouped = existing_metadata.ungrouped_jobs\n    new_ungrouped = new_metadata.ungrouped_jobs\n\n    return SessionMetadata(\n        session_name=existing_metadata.session_name or new_metadata.session_name,\n        sweep_name=existing_metadata.sweep_name or new_metadata.sweep_name,\n        total_jobs=existing_total + new_total,\n        
window_groups=merged_window_groups if merged_window_groups else None,\n        ungrouped_jobs=existing_ungrouped + new_ungrouped,\n        pane_titles=merged_pane_titles if merged_pane_titles else None,\n        job_mapping=merged_job_mapping if merged_job_mapping else None,\n    )\n\n\ndef _enable_pane_logging(\n    session_name: str,\n    window_index: int,\n    pane_index: int,\n    log_path: str,\n    verbose: bool = False,\n):\n    \"\"\"Enable logging for a specific pane\"\"\"\n    pane_target = f\"{session_name}:{window_index}.{pane_index}\"\n    # -o means \"output\" mode, -a would mean append (but pipe-pane appends by default)\n    _ = _execute_command(\n        [\"tmux\", \"pipe-pane\", \"-t\", pane_target, \"-o\", f\"cat >> {shlex.quote(log_path)}\"],\n        verbose=verbose,\n    )\n\n\ndef _configure_status_bar(session_name: str, sweep_name: str, verbose: bool = False) -> None:\n    \"\"\"Configure a multi-line status bar for the session\"\"\"\n\n    # Status bar content\n    # Line 1: Session info\n    status_left = f\"[{sweep_name}] \"\n    status_right = \"Jobs: #{window_index}/#{session_windows} | #{host} | %H:%M\"\n\n    # Line 2 will be handled by status-format\n\n    commands = [\n        # Enable status bar\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status\", \"on\"],\n        # Make it 2 lines tall\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status\", \"2\"],\n        # Configure first line\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status-left\", status_left],\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status-right\", status_right],\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status-left-length\", \"40\"],\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status-right-length\", \"60\"],\n        # Window list format (shortened names)\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"window-status-format\", \"#I:#W\"],\n        [\n            \"tmux\",\n            \"set-option\",\n            \"-s\",\n            \"-t\",\n            session_name,\n            \"window-status-current-format\",\n            \"#[bold]#I:#W#[nobold]\",\n        ],\n        # Colors\n        [\n            \"tmux\",\n            \"set-option\",\n            \"-s\",\n            \"-t\",\n            session_name,\n            \"status-style\",\n            \"bg=colour235,fg=colour248\",\n        ],\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status-left-style\", \"fg=colour39,bold\"],\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status-right-style\", \"fg=colour248\"],\n        # Window colors\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"window-status-style\", \"fg=colour248\"],\n        [\n            \"tmux\",\n            \"set-option\",\n            \"-s\",\n            \"-t\",\n            session_name,\n            \"window-status-current-style\",\n            \"fg=colour39,bold\",\n        ],\n        # Refresh interval\n        [\"tmux\", \"set-option\", \"-t\", session_name, \"status-interval\", \"5\"],\n    ]\n\n    if verbose:\n        print(colored(f\"Configuring status bar for session '{session_name}'\", \"cyan\"))\n    for cmd in commands:\n        _ = _execute_command(cmd, verbose=verbose)\n\n\ndef launch_swarm(job_specs: list[JobSpec], config: SwarmConfig) -> None:\n    \"\"\"Launch a swarm of experiments with a control window\"\"\"\n    session_name = config.get_session_name()\n\n    # Check if session already exists\n    
session_exists = _session_exists(session_name) and not config.dry_run\n    existing_metadata: SessionMetadata | None = None\n    starting_window_index = 1\n\n    if session_exists:\n        response = input(\n            colored(\n                f\"Session '{session_name}' already exists. Add new jobs to existing session? (y/N): \",\n                \"yellow\",\n            )\n        )\n        if response.lower() not in [\"y\", \"yes\"]:\n            print(colored(\"Aborted.\", \"red\"))\n            return\n        if config.verbose:\n            print(\n                colored(\n                    f\"Adding new jobs to existing session '{session_name}'.\",\n                    \"cyan\",\n                )\n            )\n        existing_metadata = load_existing_metadata(session_name)\n        starting_window_index = _get_next_window_index(existing_metadata)\n        if config.verbose:\n            print(colored(f\"Starting new jobs from window index {starting_window_index}\", \"cyan\"))\n    else:\n        if config.verbose:\n            print(colored(f\"Creating new tmux session '{session_name}'\", \"cyan\"))\n\n    # Group jobs by window name\n    window_groups: dict[str, list[JobSpec]] = {}\n    ungrouped_jobs: list[JobSpec] = []\n\n    for job_spec in job_specs:\n        if job_spec.tmux_window_name:\n            if job_spec.tmux_window_name not in window_groups:\n                window_groups[job_spec.tmux_window_name] = []\n            window_groups[job_spec.tmux_window_name].append(job_spec)\n        else:\n            ungrouped_jobs.append(job_spec)\n\n    # Split large groups if they exceed max_panes_per_window\n    final_window_groups: dict[str, list[JobSpec]] = {}\n    for window_name, jobs in window_groups.items():\n        if len(jobs) <= config.max_panes_per_window:\n            final_window_groups[window_name] = jobs\n        else:\n            # Split into multiple windows\n            for i in range(0, len(jobs), config.max_panes_per_window):\n                chunk = jobs[i : i + config.max_panes_per_window]\n                chunk_name = f\"{window_name}-{i // config.max_panes_per_window + 1}\"\n                final_window_groups[chunk_name] = chunk\n\n    ending_window_index = starting_window_index + len(final_window_groups) + len(ungrouped_jobs)\n\n    # Create control window command\n    control_script_path = os.path.join(os.path.dirname(__file__), \"control.py\")\n    # Use sys.executable to use the same Python interpreter that's running this script\n    control_cmd = (\n        config.control_window_cmd or f\"{sys.executable} {control_script_path} {session_name}\"\n    )\n\n    # Create session only if it doesn't exist\n    if not session_exists:\n        # Create session with a placeholder window first\n        _ = _execute_command(\n            [\"tmux\", \"new-session\", \"-d\", \"-s\", session_name, \"-n\", \"placeholder\", \"sleep 1\"],\n            verbose=config.verbose,\n        )\n        _ = _execute_command(\n            [\"tmux\", \"set-option\", \"-t\", session_name, \"key-table\", session_name],\n            verbose=config.verbose,\n        )\n\n        # Set session options immediately after creation\n        # Set remain-on-exit as a global option for this session\n        # This ensures all windows and panes in this session won't close on exit\n        _ = _execute_command(\n            [\"tmux\", \"set-option\", \"-t\", session_name, \"remain-on-exit\", \"on\"],\n            verbose=config.verbose,\n        )\n\n        # Enable mouse support 
- session-specific setting\n        _ = _execute_command(\n            [\"tmux\", \"set-option\", \"-t\", session_name, \"mouse\", \"on\"], verbose=config.verbose\n        )\n\n        # Add key binding to detach from the session\n        _ = _execute_command(\n            [\"tmux\", \"bind-key\", \"-T\", session_name, \"q\", \"detach-client\"],\n            verbose=config.verbose,\n        )\n\n        # Add key binding to return to control window from any pane\n        # 0 key will switch to the control window (window 0) - no prefix needed\n        _ = _execute_command(\n            [\n                \"tmux\",\n                \"bind-key\",\n                \"-T\",\n                session_name,\n                \"0\",\n                \"select-window\",\n                \"-t\",\n                f\"{session_name}:0\",\n            ],\n            verbose=config.verbose,\n        )\n\n        # Configure status bar\n        _configure_status_bar(session_name, config.sweep_name, verbose=config.verbose)\n\n    # If a new session is added, we need to add bindings for the new windows too\n    for window_index in range(starting_window_index, ending_window_index):\n        # We can only add single digit hotkeys\n        if window_index < 0 or window_index > 9:\n            continue\n        _ = _execute_command(\n            [\n                \"tmux\",\n                \"bind-key\",\n                \"-T\",\n                session_name,\n                str(window_index),\n                \"select-window\",\n                \"-t\",\n                f\"{session_name}:{window_index}\",\n            ],\n            verbose=config.verbose,\n        )\n\n    # Launch grouped jobs\n    window_index = starting_window_index\n    for window_name, jobs in final_window_groups.items():\n        if config.verbose:\n            print(colored(f\"Creating window '{window_name}' with {len(jobs)} panes\", \"blue\"))\n\n        # Create first pane in new window\n        first_job = jobs[0]\n        config_path = _save_job_config(first_job, config.use_pickle, config.dry_run)\n        job_cmd = _create_job_command(config_path, config.debug, config.dry_run)\n\n        _ = _execute_command(\n            [\n                \"tmux\",\n                \"new-window\",\n                \"-t\",\n                f\"{session_name}:{window_index}\",\n            ]\n            + _get_tmux_env_flags(job_cmd.env)\n            + [\n                \"-n\",\n                window_name,\n                job_cmd.command,\n            ],\n            verbose=config.verbose,\n        )\n        _ = _execute_command(\n            [\n                \"tmux\",\n                \"set-option\",\n                \"-t\",\n                f\"{session_name}:{window_index}\",\n                \"remain-on-exit\",\n                \"on\",\n            ],\n            verbose=config.verbose,\n        )\n\n        # Enable logging for the first pane\n        log_path = os.path.expanduser(f\"~/experiments/{first_job.log_relpath}/log.txt\")\n        _enable_pane_logging(session_name, window_index, 0, log_path, config.verbose)\n\n        # Add remaining panes\n        for i, job in enumerate(jobs[1:], 1):\n            config_path = _save_job_config(job, config.use_pickle, config.dry_run)\n            job_cmd: JobCommand = _create_job_command(config_path, config.debug, config.dry_run)\n\n            # Don't specify percentage, let tmux handle it\n            _ = _execute_command(\n                [\n                    \"tmux\",\n                   
 \"split-window\",\n                    \"-t\",\n                    f\"{session_name}:{window_index}\",\n                    \"-h\",  # Split horizontally for better layout\n                ]\n                + _get_tmux_env_flags(job_cmd.env)\n                + [\n                    job_cmd.command,\n                ],\n                verbose=config.verbose,\n            )\n\n            # Enable logging for this pane\n            log_path = os.path.expanduser(f\"~/experiments/{job.log_relpath}/log.txt\")\n            _enable_pane_logging(session_name, window_index, i, log_path, config.verbose)\n\n        # Even out the panes\n        if len(jobs) > 1:\n            _ = _execute_command(\n                [\"tmux\", \"select-layout\", \"-t\", f\"{session_name}:{window_index}\", \"tiled\"],\n                verbose=config.verbose,\n            )\n\n        window_index += 1\n\n    # Generate smart abbreviated names for all jobs\n    all_log_relpaths = [job.log_relpath for job in job_specs]\n    abbreviated_names = generate_unique_names(all_log_relpaths, max_length=20)\n    name_map = dict(zip(all_log_relpaths, abbreviated_names, strict=True))\n\n    # Launch ungrouped jobs (each in its own window)\n    for job_spec in ungrouped_jobs:\n        # Use smart abbreviated name\n        window_name = name_map[job_spec.log_relpath]\n\n        if config.verbose:\n            print(colored(f\"Creating window '{window_name}' for individual job\", \"blue\"))\n\n        config_path = _save_job_config(job_spec, config.use_pickle, config.dry_run)\n        job_cmd = _create_job_command(config_path, config.debug, config.dry_run)\n\n        _ = _execute_command(\n            [\n                \"tmux\",\n                \"new-window\",\n                \"-t\",\n                f\"{session_name}:{window_index}\",\n            ]\n            + _get_tmux_env_flags(job_cmd.env)\n            + [\n                \"-n\",\n                window_name,\n                job_cmd.command,\n            ],\n            verbose=config.verbose,\n        )\n        _ = _execute_command(\n            [\n                \"tmux\",\n                \"set-option\",\n                \"-t\",\n                f\"{session_name}:{window_index}\",\n                \"remain-on-exit\",\n                \"on\",\n            ],\n            verbose=config.verbose,\n        )\n\n        # Enable logging for this window (only one pane, so pane index is 0)\n        log_path = os.path.expanduser(f\"~/experiments/{job_spec.log_relpath}/log.txt\")\n        _enable_pane_logging(session_name, window_index, 0, log_path, config.verbose)\n\n        window_index += 1\n\n    # Save swarm metadata\n    metadata_path = os.path.expanduser(f\"~/experiments/.xmux/{session_name}.json\")\n    os.makedirs(os.path.dirname(metadata_path), exist_ok=True)\n\n    # Build pane titles mapping using smart naming\n    pane_titles: dict[str, list[str]] = {}\n    for window_name, jobs in final_window_groups.items():\n        # Always generate smart abbreviated names for panes based on log paths\n        job_paths = [job.log_relpath for job in jobs]\n        pane_titles[window_name] = generate_unique_names(job_paths, max_length=15)\n\n    # Build complete job mapping - this is what control.py needs!\n    job_mapping: dict[str, WindowJobInfo] = {}\n\n    # Add grouped jobs\n    window_idx = starting_window_index  # Start from calculated starting index\n    for window_name, jobs in final_window_groups.items():\n        panes: dict[str, PaneJobInfo] = {}\n        for 
pane_idx, job in enumerate(jobs):\n            display_name = (\n                pane_titles[window_name][pane_idx]\n                if pane_idx < len(pane_titles.get(window_name, []))\n                else f\"pane-{pane_idx}\"\n            )\n            panes[str(pane_idx)] = PaneJobInfo(\n                log_relpath=job.log_relpath,\n                display_name=display_name,\n            )\n        job_mapping[str(window_idx)] = WindowJobInfo(\n            window_name=window_name,\n            panes=panes,\n        )\n        window_idx += 1\n\n    # Add ungrouped jobs\n    for job_spec in ungrouped_jobs:\n        window_name = name_map[job_spec.log_relpath]\n        job_mapping[str(window_idx)] = WindowJobInfo(\n            window_name=window_name,\n            panes={\"0\": PaneJobInfo(log_relpath=job_spec.log_relpath, display_name=window_name)},\n        )\n        window_idx += 1\n\n    new_metadata = SessionMetadata(\n        session_name=session_name,\n        sweep_name=config.sweep_name,\n        total_jobs=len(job_specs),\n        window_groups={k: len(v) for k, v in final_window_groups.items()},\n        ungrouped_jobs=len(ungrouped_jobs),\n        pane_titles=pane_titles,\n        job_mapping=job_mapping,\n    )\n\n    # Merge with existing metadata if session exists\n    if session_exists and existing_metadata:\n        final_metadata = _merge_metadata(existing_metadata, new_metadata)\n    else:\n        final_metadata = new_metadata\n\n    with open(metadata_path, \"w\") as f:\n        _ = f.write(final_metadata.model_dump_json(exclude_none=True, indent=2))\n\n    # Create the control window only for new sessions\n    if not session_exists:\n        if config.verbose:\n            print(colored(\"Creating control window\", \"cyan\"))\n\n        # Kill the placeholder window\n        _ = _execute_command(\n            [\"tmux\", \"kill-window\", \"-t\", f\"{session_name}:0\"], verbose=config.verbose\n        )\n\n        # Create the actual control window at index 0\n        _ = _execute_command(\n            [\n                \"tmux\",\n                \"new-window\",\n                \"-t\",\n                f\"{session_name}:0\",\n                \"-n\",\n                \"control\",\n                control_cmd,\n            ],\n            verbose=config.verbose,\n        )\n\n        # Set remain-on-exit for the control window\n        _ = _execute_command(\n            [\"tmux\", \"set-option\", \"-w\", \"-t\", f\"{session_name}:0\", \"remain-on-exit\", \"on\"],\n            verbose=config.verbose,\n        )\n\n    # Switch to control window\n    _ = _execute_command(\n        [\"tmux\", \"select-window\", \"-t\", f\"{session_name}:0\"], verbose=config.verbose\n    )\n\n    # Print summary\n    print(colored(\"\\n\" + \"=\" * 60, \"green\"))\n    if session_exists:\n        print(\n            colored(f\"Jobs added to existing swarm '{config.sweep_name}'!\", \"green\", attrs=[\"bold\"])\n        )\n        print(colored(f\"Session: {session_name}\", \"green\"))\n        print(colored(f\"New jobs added: {len(job_specs)}\", \"green\"))\n        total_jobs_now = (existing_metadata.total_jobs if existing_metadata else 0) + len(job_specs)\n        print(colored(f\"Total jobs in session: {total_jobs_now}\", \"green\"))\n    else:\n        print(\n            colored(f\"Swarm '{config.sweep_name}' launched successfully!\", \"green\", attrs=[\"bold\"])\n        )\n        print(colored(f\"Session: {session_name}\", \"green\"))\n        print(colored(f\"Total jobs: 
{len(job_specs)}\", \"green\"))\n    print(colored(f\"New windows created: {window_index - starting_window_index}\", \"green\"))\n    print(colored(\"\\nTo attach to the session:\", \"cyan\"))\n    print(colored(f\"  tmux attach-session -t {session_name}\", \"cyan\", attrs=[\"bold\"]))\n    print(colored(\"\\nTo kill the entire swarm:\", \"yellow\"))\n    print(colored(f\"  tmux kill-session -t {session_name}\", \"yellow\", attrs=[\"bold\"]))\n    print(colored(\"=\" * 60 + \"\\n\", \"green\"))\n"
  },
  {
    "path": "tinker_cookbook/xmux/examples/async_rl_sweep.py",
    "content": "import argparse\nimport os\n\nimport pandas\n\nfrom tinker_cookbook import model_info\nfrom tinker_cookbook.recipes.math_rl.math_env import Gsm8kDatasetBuilder\nfrom tinker_cookbook.rl import train as rl_train\nfrom tinker_cookbook.xmux import JobSpec, SwarmConfig, launch_swarm\n\n\ndef json_already_exists(log_relpath: str) -> bool:\n    metrics_path = os.path.expanduser(f\"~/experiments/{log_relpath}/metrics.jsonl\")\n    if not os.path.exists(metrics_path):\n        return False\n    df = pandas.read_json(metrics_path, lines=True)\n    return len(df) > 0\n\n\ndef build_rl_basic_config(max_steps_off_policy: int, name: str) -> rl_train.Config:\n    model_name = \"meta-llama/Llama-3.1-8B\"\n    renderer_name = model_info.get_recommended_renderer_name(model_name)\n    builder = Gsm8kDatasetBuilder(\n        batch_size=128,\n        group_size=16,\n        renderer_name=renderer_name,\n        model_name_for_tokenizer=model_name,\n    )\n    return rl_train.Config(\n        model_name=model_name,\n        renderer_name=renderer_name,\n        log_path=f\"/tmp/tinker-examples/async_rl_sweep_{name}\",\n        dataset_builder=builder,\n        learning_rate=4e-5,\n        max_tokens=256,\n        eval_every=2,\n        async_config=rl_train.AsyncConfig(\n            max_steps_off_policy=max_steps_off_policy,\n            groups_per_batch=16,\n        )\n        if max_steps_off_policy > 0\n        else None,\n        enable_trace=True,\n    )\n\n\ndef async_rl_sweep():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dry_run\", action=\"store_true\", help=\"If set, perform a dry run (do not launch jobs)\"\n    )\n    parser.add_argument(\"--verbose\", action=\"store_true\", help=\"If set, print verbose output\")\n    args = parser.parse_args()\n\n    log_relpath_base = \"async_rl_sweep\"\n    job_specs = []\n    for max_steps_off_policy in [0, 1, 2, 4, 8, 16]:\n        tmux_window_name = (\n            f\"off_policy_{max_steps_off_policy}\" if max_steps_off_policy > 0 else \"on_policy\"\n        )\n        rl_config = build_rl_basic_config(\n            max_steps_off_policy=max_steps_off_policy,\n            name=tmux_window_name,\n        )\n        log_relpath = os.path.expanduser(f\"~/experiments/{log_relpath_base}/{tmux_window_name}\")\n\n        if json_already_exists(log_relpath):\n            print(f\"Skipping {log_relpath} because it already exists\")\n            continue\n        job_specs.append(\n            JobSpec(\n                main_fn=rl_train.main,\n                log_relpath=log_relpath,\n                entrypoint_config=rl_config,\n                tmux_window_name=tmux_window_name,\n            )\n        )\n\n    if job_specs:\n        print(f\"Launching {len(job_specs)} sweep experiments with xmux\")\n        config = SwarmConfig(\n            sweep_name=log_relpath_base,\n            max_panes_per_window=5,\n            debug=False,\n            dry_run=args.dry_run,\n            verbose=args.verbose,\n        )\n        launch_swarm(job_specs, config)\n    else:\n        print(\"No experiments to launch (all already exist)\")\n\n\nif __name__ == \"__main__\":\n    async_rl_sweep()\n"
  },
  {
    "path": "tinker_cookbook/xmux/examples/fake_train.py",
    "content": "#!/usr/bin/env python\n\"\"\"Fake training script for xmux demos\"\"\"\n\nimport random\nimport time\nfrom typing import Any\n\nfrom pydantic import BaseModel\n\n\nclass Config(BaseModel):\n    duration: int = 60\n    failure_rate: float = 0.2\n    model: str = \"unknown\"\n    lr: float = 0.001\n\n\ndef fake_train_model(config_dict: dict[str, Any]):\n    \"\"\"Simulate a training job with configurable duration and failure rate\"\"\"\n    config = Config.model_validate(config_dict)\n    assert isinstance(config, Config)\n\n    # Determine if this run will fail\n    will_fail = random.random() < config.failure_rate\n\n    print(\"Starting fake training job...\")\n    print(f\"Model: {config.model}\")\n    print(f\"Learning rate: {config.lr}\")\n    print(f\"Duration: {config.duration}s\")\n    print(f\"Config: {Config.model_dump_json(config)}\")\n    print(\"-\" * 50)\n\n    # Simulate training with periodic output\n    start_time = time.time()\n    loss: float = 2.0  # Initialize loss in case loop doesn't execute\n    for epoch in range(1, config.duration // 5 + 1):\n        if (epoch - 1) * 5 >= config.duration:\n            break\n\n        # Simulate loss decreasing over time (with some noise)\n        base_loss = 2.0 * (0.95**epoch)\n        loss = base_loss + random.uniform(-0.1, 0.1)\n\n        elapsed = int(time.time() - start_time)\n        print(f\"[Epoch {epoch:3d}] [{elapsed:3d}s] loss={loss:.4f} lr={config.lr}\")\n\n        # Random events\n        if random.random() < 0.1:\n            print(f\"[Epoch {epoch:3d}] Validation: accuracy={random.uniform(0.7, 0.95):.3f}\")\n\n        # Fail midway if designated to fail\n        if will_fail and epoch > config.duration // 10:\n            print(\"\\nERROR: Training failed due to simulated error!\")\n            print(\"Exception: Fake convergence issue detected\")\n            raise Exception(\"Fake convergence issue detected\")\n\n        time.sleep(5)\n\n    # Success\n    print(\"\\nTraining completed successfully!\")\n    print(f\"Final loss: {loss:.4f}\")\n    print(f\"Total time: {int(time.time() - start_time)}s\")\n    return 0\n\n\ndef main(config: dict[str, Any]):\n    \"\"\"Entry point that xmux will call\"\"\"\n    # For compatibility with how xmux calls this\n    return fake_train_model(config)\n\n\nif __name__ == \"__main__\":\n    # For testing standalone\n    test_config = {\"model\": \"test-model\", \"lr\": 0.01, \"duration\": 30, \"failure_rate\": 0.1}\n    exit(main(test_config))\n"
  },
  {
    "path": "tinker_cookbook/xmux/examples/ml_sweep.py",
    "content": "#!/usr/bin/env python\n\"\"\"Example ML sweep using xmux with different grouping strategies\"\"\"\n\nimport os\nimport random\nimport shutil\nimport sys\n\nfrom tinker_cookbook.xmux import JobSpec, SwarmConfig, launch_swarm\nfrom tinker_cookbook.xmux.examples.fake_train import main as fake_train_model\n\n\ndef demo_individual_windows():\n    \"\"\"Demo: Each experiment gets its own window\"\"\"\n    print(\"\\n\" + \"=\" * 60)\n    print(\"DEMO 1: Individual Windows (no grouping)\")\n    print(\"=\" * 60 + \"\\n\")\n\n    # Simulate a learning rate sweep\n    job_specs = []\n    for i, (model, lr) in enumerate(\n        [\n            (\"small\", 0.001),\n            (\"small\", 0.01),\n            (\"small\", 0.1),\n            (\"medium\", 0.001),\n            (\"medium\", 0.01),\n            (\"medium\", 0.1),\n            (\"large\", 0.001),\n            (\"large\", 0.01),\n            (\"large\", 0.1),\n        ]\n    ):\n        log_relpath = f\"demo/lr-sweep/{model}/lr{lr}\"\n        abspath = os.path.join(os.path.expanduser(\"~/experiments\"), log_relpath)\n        if os.path.exists(abspath):\n            shutil.rmtree(abspath)\n\n        # Make jobs run faster and with varying success rates\n        # First few jobs succeed, middle ones have mixed results, last ones fail\n        if i < 3:\n            failure_rate = 0.0  # First 3 always succeed\n        elif i < 6:\n            failure_rate = 0.5  # Middle 3 have 50% chance\n        else:\n            failure_rate = 1.0  # Last 3 always fail\n\n        job_specs.append(\n            JobSpec(\n                main_fn=fake_train_model,\n                log_relpath=log_relpath,\n                entrypoint_config={\n                    \"model\": model,\n                    \"lr\": lr,\n                    \"duration\": random.randint(5, 15),  # Much faster: 5-15 seconds\n                    \"failure_rate\": failure_rate,\n                },\n            )\n        )\n\n    config = SwarmConfig(sweep_name=\"lr-sweep-individual\", dry_run=\"--dry-run\" in sys.argv)\n\n    launch_swarm(job_specs, config)\n\n\ndef demo_grouped_by_model():\n    \"\"\"Demo: Group experiments by model type\"\"\"\n    print(\"\\n\" + \"=\" * 60)\n    print(\"DEMO 2: Grouped by Model\")\n    print(\"=\" * 60 + \"\\n\")\n\n    job_specs = []\n    for model in [\"bert-base\", \"bert-large\", \"gpt2\", \"t5-small\"]:\n        for lr in [1e-5, 5e-5, 1e-4]:\n            log_relpath = f\"demo/model-groups/{model}/lr{lr}\"\n\n            job_specs.append(\n                JobSpec(\n                    main_fn=fake_train_model,\n                    log_relpath=log_relpath,\n                    entrypoint_config={\n                        \"model\": model,\n                        \"lr\": lr,\n                        \"duration\": random.randint(30, 90),\n                        \"failure_rate\": 0.15,\n                    },\n                    tmux_window_name=model,  # Group by model\n                )\n            )\n\n    config = SwarmConfig(\n        sweep_name=\"model-grouped-sweep\",\n        max_panes_per_window=3,  # Max 3 learning rates per window\n        dry_run=\"--dry-run\" in sys.argv,\n    )\n\n    launch_swarm(job_specs, config)\n\n\ndef demo_mixed_grouping():\n    \"\"\"Demo: Mix of grouped and individual experiments\"\"\"\n    print(\"\\n\" + \"=\" * 60)\n    print(\"DEMO 3: Mixed Grouping Strategy\")\n    print(\"=\" * 60 + \"\\n\")\n\n    job_specs = []\n\n    # Quick experiments - group together\n    for i in range(6):\n     
   log_relpath = f\"demo/mixed/quick/exp{i}\"\n\n        job_specs.append(\n            JobSpec(\n                main_fn=fake_train_model,\n                log_relpath=log_relpath,\n                entrypoint_config={\n                    \"exp_id\": i,\n                    \"model\": f\"quick-model-{i}\",\n                    \"duration\": random.randint(10, 30),\n                    \"failure_rate\": 0.1,\n                },\n                tmux_window_name=\"quick-exps\",\n            )\n        )\n\n    # Long-running experiments - individual windows\n    for dataset in [\"imagenet\", \"coco\", \"wmt\"]:\n        for size in [\"full\", \"sample\"]:\n            log_relpath = f\"demo/mixed/long/{dataset}-{size}\"\n\n            job_specs.append(\n                JobSpec(\n                    main_fn=fake_train_model,\n                    log_relpath=log_relpath,\n                    entrypoint_config={\n                        \"dataset\": dataset,\n                        \"size\": size,\n                        \"model\": f\"{dataset}-model\",\n                        \"duration\": random.randint(180, 300),\n                        \"failure_rate\": 0.05,  # Lower failure rate for long runs\n                    },\n                    # No tmux_window_name = individual window\n                )\n            )\n\n    config = SwarmConfig(\n        sweep_name=\"mixed-strategy-demo\", max_panes_per_window=4, dry_run=\"--dry-run\" in sys.argv\n    )\n\n    launch_swarm(job_specs, config)\n\n\ndef demo_large_scale():\n    \"\"\"Demo: Large scale sweep with many experiments\"\"\"\n    print(\"\\n\" + \"=\" * 60)\n    print(\"DEMO 4: Large Scale Sweep\")\n    print(\"=\" * 60 + \"\\n\")\n\n    job_specs = []\n\n    # Grid search over many hyperparameters\n    models = [\"model-v1\", \"model-v2\", \"model-v3\"]\n    learning_rates = [1e-5, 5e-5, 1e-4, 5e-4]\n    batch_sizes = [16, 32, 64]\n    optimizers = [\"adam\", \"sgd\"]\n\n    for model in models:\n        for lr in learning_rates:\n            for bs in batch_sizes:\n                for opt in optimizers:\n                    log_relpath = f\"demo/grid/{model}/lr{lr}-bs{bs}-{opt}\"\n\n                    # Group by model and optimizer\n                    window_name = f\"{model}-{opt}\"\n\n                    job_specs.append(\n                        JobSpec(\n                            main_fn=fake_train_model,\n                            log_relpath=log_relpath,\n                            entrypoint_config={\n                                \"model\": model,\n                                \"lr\": lr,\n                                \"batch_size\": bs,\n                                \"optimizer\": opt,\n                                \"duration\": random.randint(60, 180),\n                                \"failure_rate\": 0.1,\n                            },\n                            tmux_window_name=window_name,\n                        )\n                    )\n\n    print(f\"Total experiments: {len(job_specs)}\")\n\n    config = SwarmConfig(\n        sweep_name=\"large-grid-search\", max_panes_per_window=4, dry_run=\"--dry-run\" in sys.argv\n    )\n\n    launch_swarm(job_specs, config)\n\n\ndef demo_real_usage():\n    \"\"\"Demo: How you would use xmux with real training code\"\"\"\n    print(f\"\"\"\n{\"=\" * 60}\nDEMO: Real Usage Pattern\n{\"=\" * 60}\n\nIn real usage, you would import your actual training function:\n\n```python\nfrom my_project.training import train_model\nfrom my_project.config import 
TrainingConfig\n\njob_specs = []\nfor lr in [1e-4, 5e-4, 1e-3]:\n    config = TrainingConfig(\n        model_name='bert-base',\n        learning_rate=lr,\n        batch_size=32,\n        num_epochs=10\n    )\n\n    job_specs.append(JobSpec(\n        main_fn=train_model,\n        log_relpath=f'experiments/bert/lr{{lr}}',\n        entrypoint_config=config\n    ))\n\nlaunch_swarm(job_specs, SwarmConfig(sweep_name='bert-lr-sweep'))\n```\n\"\"\")\n\n\ndef main():\n    \"\"\"Run demo based on command line argument\"\"\"\n    if len(sys.argv) < 2 or sys.argv[1] not in [\"1\", \"2\", \"3\", \"4\", \"real\", \"all\"]:\n        print(\"\"\"\nUsage: python ml_sweep.py <demo_number> [--dry-run]\n\nDemos:\n  1 - Individual windows (no grouping)\n  2 - Grouped by model\n  3 - Mixed grouping strategy\n  4 - Large scale sweep\n  real - Show real usage pattern\n  all - Run all demos\n\nAdd --dry-run to see what would be executed without running\n\"\"\")\n        sys.exit(1)\n\n    demo = sys.argv[1]\n\n    # Run requested demo(s)\n    if demo == \"1\":\n        demo_individual_windows()\n    elif demo == \"2\":\n        demo_grouped_by_model()\n    elif demo == \"3\":\n        demo_mixed_grouping()\n    elif demo == \"4\":\n        demo_large_scale()\n    elif demo == \"real\":\n        demo_real_usage()\n    elif demo == \"all\":\n        demo_individual_windows()\n        input(\"\\nPress Enter to continue to next demo...\")\n        demo_grouped_by_model()\n        input(\"\\nPress Enter to continue to next demo...\")\n        demo_mixed_grouping()\n        input(\"\\nPress Enter to continue to next demo...\")\n        demo_large_scale()\n        input(\"\\nPress Enter to continue to next demo...\")\n        demo_real_usage()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tinker_cookbook/xmux/run_job.py",
    "content": "\"\"\"Minimal job runner for xmux\"\"\"\n\nimport argparse\nimport asyncio\nimport importlib\nimport inspect\nimport pickle\nfrom collections.abc import Callable\n\nfrom .core import JobConfig\n\n\ndef get_module_member(path_with_colon: str) -> Callable[..., object]:\n    \"\"\"Import a module member from a colon-separated path\"\"\"\n    module_path, member_name = path_with_colon.split(\":\")\n    module = importlib.import_module(module_path)\n    return getattr(module, member_name)\n\n\ndef main() -> None:\n    parser = argparse.ArgumentParser()\n    _ = parser.add_argument(\"config_path\")\n    args = parser.parse_args()\n\n    # Load configuration\n    config: JobConfig\n    config_path: str = str(args.config_path)\n    if config_path.endswith(\".pickle\"):\n        with open(config_path, \"rb\") as f:\n            loaded = pickle.load(f)  # type: ignore[assignment]\n            if not isinstance(loaded, JobConfig):\n                raise ValueError(\"Pickle file does not contain a JobConfig object\")\n            config = loaded\n    elif config_path.endswith(\".json\"):\n        with open(config_path) as f:\n            config = JobConfig.model_validate_json(f.read())\n    else:\n        raise ValueError(f\"Unknown file extension: {config_path}\")\n\n    # Get and run the function\n    function: Callable[..., object] = get_module_member(config.entrypoint)\n    result: object = function(config.entrypoint_config)\n\n    # Handle async functions\n    if inspect.iscoroutine(result):\n        _ = asyncio.run(result)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tinker_cookbook/xmux/utils.py",
    "content": "\"\"\"Utility functions for xmux\"\"\"\n\nimport importlib\nimport inspect\nimport os\nimport sys\nfrom typing import Any\n\n\ndef find_common_prefix(strings: list[str]) -> str:\n    \"\"\"Find the longest common prefix among strings\"\"\"\n    if not strings:\n        return \"\"\n\n    # Sort to compare only first and last\n    sorted_strings = sorted(strings)\n    first, last = sorted_strings[0], sorted_strings[-1]\n\n    common: list[str] = []\n    for i, char in enumerate(first):\n        if i < len(last) and char == last[i]:\n            common.append(char)\n        else:\n            break\n\n    prefix = \"\".join(common)\n    # Only use prefix if it ends at a natural boundary\n    if prefix and not prefix.endswith((\"/\", \"_\", \"-\")):\n        # Find last boundary\n        for sep in [\"/\", \"_\", \"-\"]:\n            if sep in prefix:\n                prefix = prefix[: prefix.rfind(sep) + 1]\n                break\n        else:\n            # No natural boundary, don't use prefix\n            prefix = \"\"\n\n    return prefix\n\n\ndef abbreviate_path(path: str, max_length: int = 20) -> str:\n    \"\"\"Abbreviate a path to fit within max_length\"\"\"\n    # Common abbreviations\n    replacements = {\n        \"learning_rate\": \"lr\",\n        \"batch_size\": \"bs\",\n        \"num_epochs\": \"ep\",\n        \"model\": \"m\",\n        \"experiment\": \"exp\",\n        \"checkpoint\": \"ckpt\",\n        \"validation\": \"val\",\n        \"training\": \"train\",\n    }\n\n    # Apply replacements\n    result = path\n    for long_form, short_form in replacements.items():\n        result = result.replace(long_form, short_form)\n\n    # If still too long, use more aggressive abbreviation\n    if len(result) > max_length:\n        parts = result.split(\"/\")\n        if len(parts) > 1:\n            # Keep last part full, abbreviate others\n            abbreviated_parts: list[str] = []\n            for part in parts[:-1]:\n                if len(part) > 3:\n                    # Take first letter of each word/section\n                    if \"_\" in part:\n                        abbrev = \"\".join(word[0] for word in part.split(\"_\"))\n                    elif \"-\" in part:\n                        abbrev = \"\".join(word[0] for word in part.split(\"-\"))\n                    else:\n                        abbrev = part[:3]\n                    abbreviated_parts.append(abbrev)\n                else:\n                    abbreviated_parts.append(part)\n\n            abbreviated_parts.append(parts[-1])\n            result = \"/\".join(abbreviated_parts)\n\n    # Final truncation if needed\n    if len(result) > max_length:\n        result = \"...\" + result[-(max_length - 3) :]\n\n    return result\n\n\ndef generate_unique_names(paths: list[str], max_length: int = 20) -> list[str]:\n    \"\"\"Generate unique abbreviated names for a list of paths\"\"\"\n    # Find common prefix to remove\n    common_prefix = find_common_prefix(paths)\n\n    # Remove prefix and abbreviate\n    names: list[str] = []\n    seen_names: set[str] = set()\n\n    for path in paths:\n        # Remove common prefix\n        if common_prefix and path.startswith(common_prefix):\n            shortened = path[len(common_prefix) :]\n        else:\n            shortened = path\n\n        # Abbreviate\n        name = abbreviate_path(shortened, max_length)\n\n        # Ensure uniqueness\n        if name in seen_names:\n            # Add a counter\n            counter = 2\n            while 
f\"{name}-{counter}\" in seen_names:\n                counter += 1\n            name = f\"{name}-{counter}\"\n\n        seen_names.add(name)\n        names.append(name)\n\n    return names\n\n\ndef smart_window_name(\n    log_relpath: str, session_context: list[str] | None = None, max_length: int = 20\n) -> str:\n    \"\"\"Generate a smart window name for a single job\"\"\"\n    if session_context:\n        # Use context to find common patterns\n        all_paths = session_context + [log_relpath]\n        names = generate_unique_names(all_paths, max_length)\n        return names[-1]  # Return name for the new path\n    else:\n        # No context, just abbreviate\n        return abbreviate_path(log_relpath, max_length)\n\n\ndef format_status_bar_windows(window_names: list[str], max_width: int = 80) -> str:\n    \"\"\"Format window names for status bar display\"\"\"\n    # Format: [0:ctrl] [1:name1] [2:name2] ...\n    formatted: list[str] = []\n    current_width = 0\n\n    for i, name in enumerate(window_names):\n        if i == 0:\n            item = \"[0:ctrl]\"\n        else:\n            # Check if this is a grouped window\n            if name.count(\"-\") > 1 and name.split(\"-\")[-1].isdigit():\n                # Grouped window, add indicator\n                base_name = \"-\".join(name.split(\"-\")[:-1])\n                item = f\"[{i}:{base_name}*]\"\n            else:\n                item = f\"[{i}:{name}]\"\n\n        item_width = len(item) + 1  # +1 for space\n\n        if current_width + item_width > max_width and formatted:\n            # Would overflow, stop here\n            remaining = len(window_names) - len(formatted)\n            formatted.append(f\"... +{remaining}\")\n            break\n\n        formatted.append(item)\n        current_width += item_width\n\n    return \" \".join(formatted)\n\n\nclass SymbolPath(str):\n    module: str\n    name: str\n\n    def __new__(cls, module: str, name: str):\n        value = f\"{module}:{name}\"\n        obj = super().__new__(cls, value)\n        obj.module = module\n        obj.name = name\n        return obj\n\n    @classmethod\n    def from_string(cls, value: str):\n        module, name = value.split(\":\")\n        return cls(module, name)\n\n    def __reduce__(self):\n        return self.__class__, (self.module, self.name)\n\n\ndef get_symbol_path(cls_or_fn: Any) -> SymbolPath:\n    \"\"\"\n    Get the full module path and class name in format 'module.path:ClassName'\n\n    Args:\n        cls: The class to serialize\n\n    Returns:\n        str: The full path in format 'module.path:ClassName'\n\n    Example:\n        >>> class MyClass: pass\n        >>> get_class_path(MyClass)\n        '__main__:MyClass'\n\n        >>> from collections import defaultdict\n        >>> get_class_path(defaultdict)\n        'collections:defaultdict'\n    \"\"\"\n    module = cls_or_fn.__module__\n    class_name = cls_or_fn.__name__\n\n    # If it's in __main__, resolve to actual module path.\n    # This happens if one defines classes in the executable.\n    if module != \"__main__\":\n        return SymbolPath(module, class_name)\n\n    # Get the file where the function is defined\n    file_path = os.path.abspath(inspect.getfile(cls_or_fn))\n\n    # Check for monorepo root in sys.path (including the empty string case)\n    # We have a set of candidates (for example for a file in /a/b/c.py, this could be\n    # - a.b.c relative to the monorepo root\n    # - b.c relative to module `a` if `a` is in sys.path\n    # so just return the shortest 
candidate match\n    relative_candidates = []\n    for path_entry in sys.path:\n        # Convert empty string to current directory\n        actual_path = os.path.abspath(path_entry) if path_entry else os.getcwd()\n\n        # Check if file is within this path\n        if file_path.startswith(actual_path):\n            # Get relative path from this entry\n            rel_path = os.path.relpath(file_path, actual_path)\n            # Convert to module path\n            module_name = os.path.splitext(rel_path)[0].replace(os.path.sep, \".\")\n            relative_candidates.append(module_name)\n\n    # Try importing each candidate and check if it matches the file\n    # Sort by number of dots to prefer shorter paths (preserves historical behavior)\n    # However, skip single-component module names as they're typically not portable\n    # across different execution environments (e.g., remote workers)\n    for candidate in sorted(relative_candidates, key=lambda x: x.count(\".\")):\n        # Skip single-component names unless no other option exists\n        if \".\" not in candidate and len(relative_candidates) > 1:\n            continue\n        try:\n            mod = importlib.import_module(candidate)\n            # Check if the imported module's file matches the original file\n            if os.path.samefile(getattr(mod, \"__file__\", \"\"), file_path):\n                return SymbolPath(candidate, class_name)\n        except Exception:\n            continue\n    raise ValueError(\n        f\"Could not find valid importable module name for function {cls_or_fn} in {file_path}.\"\n    )\n"
  }
]