Repository: huggingface/ml-intern Branch: main Commit: b292d83aa78e Files: 110 Total size: 948.1 KB Directory structure: gitextract_xwbm9csz/ ├── .gitattributes ├── .github/ │ └── workflows/ │ ├── claude-review.yml │ └── claude.yml ├── .gitignore ├── .python-version ├── Dockerfile ├── README.md ├── REVIEW.md ├── agent/ │ ├── README.md │ ├── __init__.py │ ├── config.py │ ├── context_manager/ │ │ ├── __init__.py │ │ └── manager.py │ ├── core/ │ │ ├── __init__.py │ │ ├── agent_loop.py │ │ ├── doom_loop.py │ │ ├── effort_probe.py │ │ ├── hf_router_catalog.py │ │ ├── llm_params.py │ │ ├── model_switcher.py │ │ ├── prompt_caching.py │ │ ├── session.py │ │ ├── session_uploader.py │ │ └── tools.py │ ├── main.py │ ├── prompts/ │ │ ├── system_prompt.yaml │ │ ├── system_prompt_v2.yaml │ │ └── system_prompt_v3.yaml │ ├── tools/ │ │ ├── __init__.py │ │ ├── dataset_tools.py │ │ ├── docs_tools.py │ │ ├── edit_utils.py │ │ ├── github_find_examples.py │ │ ├── github_list_repos.py │ │ ├── github_read_file.py │ │ ├── hf_repo_files_tool.py │ │ ├── hf_repo_git_tool.py │ │ ├── jobs_tool.py │ │ ├── local_tools.py │ │ ├── papers_tool.py │ │ ├── plan_tool.py │ │ ├── private_hf_repo_tools.py │ │ ├── research_tool.py │ │ ├── sandbox_client.py │ │ ├── sandbox_tool.py │ │ ├── types.py │ │ └── utilities.py │ └── utils/ │ ├── __init__.py │ ├── boot_timing.py │ ├── braille.py │ ├── crt_boot.py │ ├── particle_logo.py │ ├── reliability_checks.py │ └── terminal_display.py ├── backend/ │ ├── __init__.py │ ├── dependencies.py │ ├── main.py │ ├── models.py │ ├── routes/ │ │ ├── __init__.py │ │ ├── agent.py │ │ └── auth.py │ ├── session_manager.py │ ├── start.sh │ └── user_quotas.py ├── configs/ │ └── main_agent_config.json ├── frontend/ │ ├── eslint.config.js │ ├── index.html │ ├── package.json │ ├── src/ │ │ ├── App.tsx │ │ ├── components/ │ │ │ ├── Chat/ │ │ │ │ ├── ActivityStatusBar.tsx │ │ │ │ ├── AssistantMessage.tsx │ │ │ │ ├── ChatInput.tsx │ │ │ │ ├── ExpiredBanner.tsx │ │ │ │ ├── MarkdownContent.tsx │ │ │ │ ├── MessageBubble.tsx │ │ │ │ ├── MessageList.tsx │ │ │ │ ├── ThinkingIndicator.tsx │ │ │ │ ├── ToolCallGroup.tsx │ │ │ │ └── UserMessage.tsx │ │ │ ├── ClaudeCapDialog.tsx │ │ │ ├── CodePanel/ │ │ │ │ └── CodePanel.tsx │ │ │ ├── Layout/ │ │ │ │ └── AppLayout.tsx │ │ │ ├── SessionChat.tsx │ │ │ ├── SessionSidebar/ │ │ │ │ └── SessionSidebar.tsx │ │ │ └── WelcomeScreen/ │ │ │ └── WelcomeScreen.tsx │ │ ├── hooks/ │ │ │ ├── useAgentChat.ts │ │ │ ├── useAuth.ts │ │ │ ├── useOrgMembership.ts │ │ │ └── useUserQuota.ts │ │ ├── lib/ │ │ │ ├── backend-message-store.ts │ │ │ ├── chat-message-store.ts │ │ │ ├── convert-llm-messages.ts │ │ │ ├── research-store.ts │ │ │ └── sse-chat-transport.ts │ │ ├── main.tsx │ │ ├── store/ │ │ │ ├── agentStore.ts │ │ │ ├── layoutStore.ts │ │ │ └── sessionStore.ts │ │ ├── theme.ts │ │ ├── types/ │ │ │ ├── agent.ts │ │ │ └── events.ts │ │ ├── utils/ │ │ │ ├── api.ts │ │ │ ├── logProcessor.ts │ │ │ ├── logger.ts │ │ │ └── model.ts │ │ └── vite-env.d.ts │ ├── tsconfig.json │ └── vite.config.ts ├── pyproject.toml └── tests/ └── unit/ └── test_user_quotas.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ *.png filter=lfs diff=lfs merge=lfs -text ================================================ FILE: .github/workflows/claude-review.yml ================================================ name: Claude PR Review on: 
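# Review triggers (note, not in the original file): runs on open/synchronize/ready_for_review;
# draft PRs are filtered by the `if` guard on the review job below.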
pull_request: types: [opened, synchronize, ready_for_review] permissions: contents: read pull-requests: write issues: read id-token: write concurrency: group: claude-review-${{ github.event.pull_request.number }} cancel-in-progress: true jobs: review: if: github.event.pull_request.draft == false runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Compose review prompt id: compose run: | { printf 'prompt<> "$GITHUB_OUTPUT" - uses: anthropics/claude-code-action@v1 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} track_progress: true prompt: ${{ steps.compose.outputs.prompt }} ================================================ FILE: .github/workflows/claude.yml ================================================ name: Claude on Mention on: issue_comment: types: [created] pull_request_review_comment: types: [created] pull_request_review: types: [submitted] issues: types: [opened, assigned] permissions: contents: write pull-requests: write issues: write id-token: write jobs: claude: if: | (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 - uses: anthropics/claude-code-action@v1 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} track_progress: true ================================================ FILE: .gitignore ================================================ # Python-generated files __pycache__/ *.py[oc] build/ dist/ wheels/ *.egg-info .pytest_cache/ .mypy_cache/ .tox/ .coverage htmlcov/ .ipynb_checkpoints/ # Virtual environments .venv/ venv/ ENV/ env/ # Environment and Secrets .env .env.local .env.* !.env.example *.local credentials*.json # OS-specific .DS_Store Thumbs.db *.swp # IDE-specific .vscode/ .idea/ .cursor/ .history/ *.sublime-project *.sublime-workspace # Frontend (Node.js) frontend/node_modules/ frontend/dist/ frontend/.cache/ frontend/*.local frontend/.eslintcache frontend/npm-debug.log* frontend/yarn-debug.log* frontend/yarn-error.log* # Docker .docker/ # Eval (stale) eval/ # Project-specific session_logs/ /logs hf-agent-leaderboard/ skills/ .claude/ *.jsonl *.csv # ML / Data data/ datasets/ models/ checkpoint-*/ runs/ wandb/ frontend/tsconfig.tsbuildinfo ================================================ FILE: .python-version ================================================ 3.12 ================================================ FILE: Dockerfile ================================================ # Stage 1: Build frontend FROM node:20-alpine AS frontend-builder WORKDIR /app/frontend COPY frontend/package.json frontend/package-lock.json ./ RUN npm install COPY frontend/ ./ RUN npm run build # Stage 2: Production FROM python:3.12-slim # Install uv directly from official image COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ # Create user with UID 1000 (required for HF Spaces) RUN useradd -m -u 1000 user WORKDIR /app # Install system dependencies RUN apt-get update && apt-get install -y --no-install-recommends \ git \ curl \ && rm -rf /var/lib/apt/lists/* # Copy dependency files COPY pyproject.toml uv.lock ./ # Install dependencies into /app/.venv # Use --frozen to ensure exact versions 
from uv.lock RUN uv sync --no-dev --frozen # Copy application code COPY agent/ ./agent/ COPY backend/ ./backend/ COPY configs/ ./configs/ # Copy built frontend COPY --from=frontend-builder /app/frontend/dist ./static/ # Create directories and set ownership RUN mkdir -p /app/session_logs && \ chown -R user:user /app # Switch to non-root user USER user # Set environment ENV HOME=/home/user \ PYTHONUNBUFFERED=1 \ PYTHONPATH=/app \ PATH="/app/.venv/bin:$PATH" # Expose port EXPOSE 7860 # Run the application from backend directory WORKDIR /app/backend CMD ["bash", "start.sh"] ================================================ FILE: README.md ================================================

[image: smolagents logo]

# ML Intern An ML intern that autonomously researches, writes, and ships good quality ML releated code using the Hugging Face ecosystem — with deep access to docs, papers, datasets, and cloud compute. ## Quick Start ### Installation ```bash git clone git@github.com:huggingface/ml-intern.git cd ml-intern uv sync uv tool install -e . ``` #### That's it. Now `ml-intern` works from any directory: ```bash ml-intern ``` Create a `.env` file in the project root (or export these in your shell): ```bash ANTHROPIC_API_KEY= # if using anthropic models HF_TOKEN= GITHUB_TOKEN= ``` If no `HF_TOKEN` is set, the CLI will prompt you to paste one on first launch. To get a GITHUB_TOKEN follow the tutorial [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-fine-grained-personal-access-token). ### Usage **Interactive mode** (start a chat session): ```bash ml-intern ``` **Headless mode** (single prompt, auto-approve): ```bash ml-intern "fine-tune llama on my dataset" ``` **Options:** ```bash ml-intern --model anthropic/claude-opus-4-6 "your prompt" ml-intern --max-iterations 100 "your prompt" ml-intern --no-stream "your prompt" ``` ## Architecture ### Component Overview ``` ┌─────────────────────────────────────────────────────────────┐ │ User/CLI │ └────────────┬─────────────────────────────────────┬──────────┘ │ Operations │ Events ↓ (user_input, exec_approval, ↑ submission_queue interrupt, compact, ...) event_queue │ │ ↓ │ ┌────────────────────────────────────────────────────┐ │ │ submission_loop (agent_loop.py) │ │ │ ┌──────────────────────────────────────────────┐ │ │ │ │ 1. Receive Operation from queue │ │ │ │ │ 2. Route to handler (run_agent/compact/...) │ │ │ │ └──────────────────────────────────────────────┘ │ │ │ ↓ │ │ │ ┌──────────────────────────────────────────────┐ │ │ │ │ Handlers.run_agent() │ ├──┤ │ │ │ │ │ │ │ ┌────────────────────────────────────────┐ │ │ │ │ │ │ Agentic Loop (max 300 iterations) │ │ │ │ │ │ │ │ │ │ │ │ │ │ ┌──────────────────────────────────┐ │ │ │ │ │ │ │ │ Session │ │ │ │ │ │ │ │ │ ┌────────────────────────────┐ │ │ │ │ │ │ │ │ │ │ ContextManager │ │ │ │ │ │ │ │ │ │ │ • Message history │ │ │ │ │ │ │ │ │ │ │ (litellm.Message[]) │ │ │ │ │ │ │ │ │ │ │ • Auto-compaction (170k) │ │ │ │ │ │ │ │ │ │ │ • Session upload to HF │ │ │ │ │ │ │ │ │ │ └────────────────────────────┘ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ ┌────────────────────────────┐ │ │ │ │ │ │ │ │ │ │ ToolRouter │ │ │ │ │ │ │ │ │ │ │ ├─ HF docs & research │ │ │ │ │ │ │ │ │ │ │ ├─ HF repos, datasets, │ │ │ │ │ │ │ │ │ │ │ │ jobs, papers │ │ │ │ │ │ │ │ │ │ │ ├─ GitHub code search │ │ │ │ │ │ │ │ │ │ │ ├─ Sandbox & local tools │ │ │ │ │ │ │ │ │ │ │ ├─ Planning │ │ │ │ │ │ │ │ │ │ │ └─ MCP server tools │ │ │ │ │ │ │ │ │ │ └────────────────────────────┘ │ │ │ │ │ │ │ │ └──────────────────────────────────┘ │ │ │ │ │ │ │ │ │ │ │ │ │ │ ┌──────────────────────────────────┐ │ │ │ │ │ │ │ │ Doom Loop Detector │ │ │ │ │ │ │ │ │ • Detects repeated tool patterns │ │ │ │ │ │ │ │ │ • Injects corrective prompts │ │ │ │ │ │ │ │ └──────────────────────────────────┘ │ │ │ │ │ │ │ │ │ │ │ │ │ │ Loop: │ │ │ │ │ │ │ 1. LLM call (litellm.acompletion) │ │ │ │ │ │ │ ↓ │ │ │ │ │ │ │ 2. Parse tool_calls[] │ │ │ │ │ │ │ ↓ │ │ │ │ │ │ │ 3. Approval check │ │ │ │ │ │ │ (jobs, sandbox, destructive ops) │ │ │ │ │ │ │ ↓ │ │ │ │ │ │ │ 4. Execute via ToolRouter │ │ │ │ │ │ │ ↓ │ │ │ │ │ │ │ 5. Add results to ContextManager │ │ │ │ │ │ │ ↓ │ │ │ │ │ │ │ 6. 
Repeat if tool_calls exist │ │ │ │ │ │ └────────────────────────────────────────┘ │ │ │ │ └──────────────────────────────────────────────┘ │ │ └────────────────────────────────────────────────────┴──┘ ``` ### Agentic Loop Flow ``` User Message ↓ [Add to ContextManager] ↓ ╔═══════════════════════════════════════════╗ ║ Iteration Loop (max 300) ║ ║ ║ ║ Get messages + tool specs ║ ║ ↓ ║ ║ litellm.acompletion() ║ ║ ↓ ║ ║ Has tool_calls? ──No──> Done ║ ║ │ ║ ║ Yes ║ ║ ↓ ║ ║ Add assistant msg (with tool_calls) ║ ║ ↓ ║ ║ Doom loop check ║ ║ ↓ ║ ║ For each tool_call: ║ ║ • Needs approval? ──Yes──> Wait for ║ ║ │ user confirm ║ ║ No ║ ║ ↓ ║ ║ • ToolRouter.execute_tool() ║ ║ • Add result to ContextManager ║ ║ ↓ ║ ║ Continue loop ─────────────────┐ ║ ║ ↑ │ ║ ║ └───────────────────────┘ ║ ╚═══════════════════════════════════════════╝ ``` ## Events The agent emits the following events via `event_queue`: - `processing` - Starting to process user input - `ready` - Agent is ready for input - `assistant_chunk` - Streaming token chunk - `assistant_message` - Complete LLM response text - `assistant_stream_end` - Token stream finished - `tool_call` - Tool being called with arguments - `tool_output` - Tool execution result - `tool_log` - Informational tool log message - `tool_state_change` - Tool execution state transition - `approval_required` - Requesting user approval for sensitive operations - `turn_complete` - Agent finished processing - `error` - Error occurred during processing - `interrupted` - Agent was interrupted - `compacted` - Context was compacted - `undo_complete` - Undo operation completed - `shutdown` - Agent shutting down ## Development ### Adding Built-in Tools Edit `agent/core/tools.py`: ```python def create_builtin_tools() -> list[ToolSpec]: return [ ToolSpec( name="your_tool", description="What your tool does", parameters={ "type": "object", "properties": { "param": {"type": "string", "description": "Parameter description"} }, "required": ["param"] }, handler=your_async_handler ), # ... existing tools ] ``` ### Adding MCP Servers Edit `configs/main_agent_config.json`: ```json { "model_name": "anthropic/claude-sonnet-4-5-20250929", "mcpServers": { "your-server-name": { "transport": "http", "url": "https://example.com/mcp", "headers": { "Authorization": "Bearer ${YOUR_TOKEN}" } } } } ``` Note: Environment variables like `${YOUR_TOKEN}` are auto-substituted from `.env`. ================================================ FILE: REVIEW.md ================================================ # Review instructions These rules override the default review guidance. Treat them as the highest-priority instruction block for any review of this repo. If something here contradicts a more generic review habit, follow these. ## Severity levels Every finding carries one of three priority labels: - **P0** — blocks merge. - **P1** — worth fixing, not blocking. - **P2** — informational. Write labels as plain text (`P0`, `P1`, `P2`) in finding headers. Do not use emoji or colored markers. Use judgment on what belongs at which level — this repo does not enumerate P0 cases; read the code and decide. ## Default bias: rigor Reviews gate merges. This is an open-source repo that takes PRs from anyone; the maintainer team is small and relies on the review to catch what they don't have time to verify themselves. 
**Default bias is rigor, not speed.** When in doubt on a P0-class concern, investigate further before deciding whether to flag — a false negative ships a bug to production; a false positive costs the contributor one round trip.

Rigor is not nitpicking. The P1 cap, "do not report" skip list, and verification bar all still apply. Rigor means going deep on a small number of real concerns, not surfacing a large number of shallow ones. Prefer one well-investigated P0 over three speculative P1s.

**Hold the line on P0.** If the author pushes back on a P0 finding without a fix that actually addresses the root cause, re-state the concern with added citations. Only accept the pushback if the author points to code or behavior you missed. Do not soften a P0 because the contributor is polite or new to the repo.

For P1 and P2: if the author defers or pushes back without fixing, accept it silently — do not re-flag on subsequent commits. P1/P2 are informational; the author may defer to a follow-up issue at their discretion.

If Claude and the author repeatedly disagree on the same class of finding, the signal is that REVIEW.md is missing a rule; note it once in the PR summary as `suggest-rule: ` and stop.

## Investigate before posting

The depth of your analysis determines the strength of your finding. For any P0-class concern, before writing it up:

- Read the relevant callers and callees, not just the diff. Use Read and Grep to open files the diff doesn't touch but the changed code interacts with.
- Trace the full chain end-to-end for routing, auth, and agent-loop findings. Cite each hop by `file:line`, not just the suspicious line.
- Check whether the codebase already has an established pattern for this kind of change (`grep` for similar call sites, similar tool definitions, similar route guards). If the PR introduces a new approach where an established pattern exists, flag that — divergence from the existing pattern is usually a regression vector even when the new code "works."
- Confirm the specific behavior you're claiming. "This breaks X" must be grounded in either the code handling X or a test exercising X, not in inference from naming or structure.

A finding you "spotted" by scanning the diff is more likely to be a false positive than a finding you verified by reading the code around it.

## P1 cap

Report at most **3** P1 findings per review. If you found more, say "plus N similar items" in the summary. If everything you found is P1 or below, open the summary with "No blocking issues."

## Re-review convergence

If this PR has already received a Claude review (there is a prior review comment by the `claude` bot), suppress new P1 findings and post only P0 ones. Do not re-post P1s that were already flagged on earlier commits. If the author pushed a fix for a previously flagged issue, acknowledge it in one line rather than re-flagging.

## Do not report

Anything in these paths — skip entirely:

- `frontend/node_modules/**`, `**/*.lock`, `uv.lock`, `package-lock.json`
- `hf_agent.egg-info/**`, `.ruff_cache/**`, `.pytest_cache/**`, `.venv/**`
- `session_logs/**`, `reports/**`
- Anything under a `gen/` or `generated/` path

Anything speculative — do not post:

- "This might be slow" without a concrete complexity claim tied to a specific input size
- Hypothetical race conditions without a concrete interleaving

## Dependency PRs

For PRs whose diff is only a lockfile bump, a `pyproject.toml` change, or a new dependency, the code rules above don't apply — risks shift to provenance and framing.
Every claim in the title or body (CVE IDs, version numbers, behavior fixes) must match what the diff actually does, and any new transitive dep needs justification. A PR that lies in its framing is P0 regardless of whether the code change is safe in isolation.

## Verification bar

Every behavior claim in a finding must cite `file:line`. "This breaks X" is not actionable without a line reference. If you cannot cite a line, do not post the finding.

## Summary shape

Open the review body with a single-line tally and an explicit merge verdict, on two lines:

```
2 P0, 3 P1
Verdict: changes requested
```

Valid verdicts:

- **Verdict: ready to merge** — no P0 findings; the contributor can merge as-is once any CI passes
- **Verdict: changes requested** — at least one P0 that must be addressed before merging
- **Verdict: needs discussion** — a design-level concern the maintainer should weigh in on before the contributor iterates (use sparingly)

If it's a clean review, write `LGTM` followed by `Verdict: ready to merge`.

Then a **What I checked** bullet list — one line per major area you examined, regardless of whether you found anything. This gives the maintainer visible coverage at a glance and lets them decide whether to spot-check areas you didn't touch.

================================================
FILE: agent/README.md
================================================
# Agent

Async agent loop with LiteLLM.

## Architecture

**Queue-based async system:**

- Submissions in (user input) → Agent Loop → Events output for possible UI updates
- Session maintains state (context + tools) for possible future Context Engineering
- Handles operations (USER_INPUT, INTERRUPT, COMPACT, UNDO, SHUTDOWN) for possible UI control

## Components

| Component | Purpose | Long Term Goal |
|-----------|---------|----------------|
| **`agent_loop.py`** | Core agentic loop: processes user input, calls the LLM via LiteLLM, executes tool calls iteratively until completion, emits events | Support parallel tool execution, streaming responses, and advanced reasoning patterns |
| **`session.py`** | Maintains session state and interaction with a potential UI (context, config, event queue), handles interrupts, assigns unique session IDs for tracing | Enable plugging in different UIs (CLI, web, API, programmatic, etc.) |
| **`tools.py`** | `ToolRouter` manages built-in tools (e.g. bash, read_file, write_file, which are dummy implementations right now) plus MCP tools, and converts specs to OpenAI format | Be the place for tools that can be used by the agent. All crazy tool design happens here. |
| **`context_manager/`** | Manages conversation history; very rudimentary context-engineering support | Implement intelligent context engineering to keep the agent on track |
| **`config.py`** | Loads the JSON config for the agent | Support different configs, etc. |
| **`main.py`** | Interactive CLI with async queue architecture (submission→agent, agent→events); the simplest way to interact with the agent today | Serve as a reference implementation for other UIs (web, API, programmatic) |

================================================
FILE: agent/__init__.py
================================================
"""
HF Agent - Main agent module
"""

import litellm

# Global LiteLLM behavior — set once at package import so both CLI and
# backend entries share the same config.
# drop_params: quietly drop unsupported params rather than raising # suppress_debug_info: hide the noisy "Give Feedback" banner on errors # modify_params: let LiteLLM patch Anthropic's tool-call requirements # (synthesize a dummy tool spec when we call completion on a history # that contains tool_calls but aren't passing `tools=` — happens # during summarization / session seeding). litellm.drop_params = True litellm.suppress_debug_info = True litellm.modify_params = True from agent.core.agent_loop import submission_loop # noqa: E402 __all__ = ["submission_loop"] ================================================ FILE: agent/config.py ================================================ import json import os import re from pathlib import Path from typing import Any, Union from dotenv import load_dotenv # Project root: two levels up from this file (agent/config.py -> project root) _PROJECT_ROOT = Path(__file__).resolve().parent.parent from fastmcp.mcp_config import ( RemoteMCPServer, StdioMCPServer, ) from pydantic import BaseModel # These two are the canonical server config types for MCP servers. MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer] class Config(BaseModel): """Configuration manager""" model_name: str mcpServers: dict[str, MCPServerConfig] = {} save_sessions: bool = True session_dataset_repo: str = "akseljoonas/hf-agent-sessions" auto_save_interval: int = 3 # Save every N user turns (0 = disabled) yolo_mode: bool = False # Auto-approve all tool calls without confirmation max_iterations: int = 300 # Max LLM calls per agent turn (-1 = unlimited) # Permission control parameters confirm_cpu_jobs: bool = True auto_file_upload: bool = False # Reasoning effort *preference* — the ceiling the user wants. The probe # on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high`` # → …) and caches per-model what the provider actually accepted in # ``Session.model_effective_effort``. Default ``max`` because we'd rather # burn tokens thinking than ship a wrong ML recipe; the cascade lands on # whichever level the model supports (``high`` for GPT-5 / HF router, # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off. # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max" reasoning_effort: str | None = "max" def substitute_env_vars(obj: Any) -> Any: """ Recursively substitute environment variables in any data structure. Supports ${VAR_NAME} syntax for required variables and ${VAR_NAME:-default} for optional. """ if isinstance(obj, str): pattern = r"\$\{([^}:]+)(?::(-)?([^}]*))?\}" def replacer(match): var_name = match.group(1) has_default = match.group(2) is not None default_value = match.group(3) if has_default else None env_value = os.environ.get(var_name) if env_value is not None: return env_value elif has_default: return default_value or "" else: raise ValueError( f"Environment variable '{var_name}' is not set. " f"Add it to your .env file." ) return re.sub(pattern, replacer, obj) elif isinstance(obj, dict): return {key: substitute_env_vars(value) for key, value in obj.items()} elif isinstance(obj, list): return [substitute_env_vars(item) for item in obj] return obj def load_config(config_path: str = "config.json") -> Config: """ Load configuration with environment variable substitution. Use ${VAR_NAME} in your JSON for any secret. Automatically loads from .env file. 
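
    Example (illustrative values; the real config for this repo is
    configs/main_agent_config.json, and PORT is a made-up variable):

        config = load_config("configs/main_agent_config.json")
        # In the JSON, "${HF_TOKEN}" is replaced from the environment,
        # and "${PORT:-7860}" falls back to "7860" when PORT is unset.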
""" # Load .env from project root first (so it works from any directory), # then CWD .env can override if present load_dotenv(_PROJECT_ROOT / ".env") load_dotenv(override=False) with open(config_path, "r") as f: raw_config = json.load(f) config_with_env = substitute_env_vars(raw_config) return Config.model_validate(config_with_env) ================================================ FILE: agent/context_manager/__init__.py ================================================ """ Context manager for handling conversation history """ from agent.context_manager.manager import ContextManager __all__ = ["ContextManager"] ================================================ FILE: agent/context_manager/manager.py ================================================ """ Context management for conversation history """ import logging import os import zoneinfo from datetime import datetime from pathlib import Path from typing import Any import yaml from jinja2 import Template from litellm import Message, acompletion from agent.core.prompt_caching import with_prompt_caching logger = logging.getLogger(__name__) _HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2" _HF_WHOAMI_TIMEOUT = 5 # seconds def _get_hf_username(hf_token: str | None = None) -> str: """Return the HF username for the given token. Uses subprocess + curl to avoid Python HTTP client IPv6 issues that cause 40+ second hangs (httpx/urllib try IPv6 first which times out at OS level before falling back to IPv4 — the "Happy Eyeballs" problem). """ import json import subprocess import time as _t if not hf_token: logger.warning("No hf_token provided, using 'unknown' as username") return "unknown" t0 = _t.monotonic() try: result = subprocess.run( [ "curl", "-s", "-4", # force IPv4 "-m", str(_HF_WHOAMI_TIMEOUT), # max time "-H", f"Authorization: Bearer {hf_token}", _HF_WHOAMI_URL, ], capture_output=True, text=True, timeout=_HF_WHOAMI_TIMEOUT + 2, ) t1 = _t.monotonic() if result.returncode == 0 and result.stdout: data = json.loads(result.stdout) username = data.get("name", "unknown") logger.info(f"HF username resolved to '{username}' in {t1 - t0:.2f}s") return username else: logger.warning( f"curl whoami failed (rc={result.returncode}) in {t1 - t0:.2f}s" ) return "unknown" except Exception as e: t1 = _t.monotonic() logger.warning(f"HF whoami failed in {t1 - t0:.2f}s: {e}") return "unknown" _COMPACT_PROMPT = ( "Please provide a concise summary of the conversation above, focusing on " "key decisions, the 'why' behind the decisions, problems solved, and " "important context needed for developing further. Your summary will be " "given to someone who has never worked on this project before and they " "will be have to be filled in." ) # Used when seeding a brand-new session from prior browser-cached messages. # Here we're writing a note to *ourselves* — so preserve the tool-call trail, # files produced, and planned next steps in first person. Optimized for # continuity, not brevity. _RESTORE_PROMPT = ( "You're about to be restored into a fresh session with no memory of the " "conversation above. Write a first-person note to your future self so " "you can continue right where you left off. Include:\n" " • What the user originally asked for and what progress you've made.\n" " • Every tool you called, with arguments and a one-line result summary.\n" " • Any code, files, scripts, or artifacts you produced (with paths).\n" " • Key decisions and the reasoning behind them.\n" " • What you were planning to do next.\n\n" "Don't be cute. Be specific. 
This is the only context you'll have." ) async def summarize_messages( messages: list[Message], model_name: str, hf_token: str | None = None, max_tokens: int = 2000, tool_specs: list[dict] | None = None, prompt: str = _COMPACT_PROMPT, ) -> tuple[str, int]: """Run a summarization prompt against a list of messages. ``prompt`` defaults to the compaction prompt (terse, decision-focused). Callers seeding a new session after a restart should pass ``_RESTORE_PROMPT`` instead — it preserves the tool-call trail so the agent can answer follow-up questions about what it did. Returns ``(summary_text, completion_tokens)``. """ from agent.core.llm_params import _resolve_llm_params prompt_messages = list(messages) + [Message(role="user", content=prompt)] llm_params = _resolve_llm_params(model_name, hf_token, reasoning_effort="high") prompt_messages, tool_specs = with_prompt_caching( prompt_messages, tool_specs, llm_params.get("model") ) response = await acompletion( messages=prompt_messages, max_completion_tokens=max_tokens, tools=tool_specs, **llm_params, ) summary = response.choices[0].message.content or "" completion_tokens = response.usage.completion_tokens if response.usage else 0 return summary, completion_tokens class ContextManager: """Manages conversation context and message history for the agent""" def __init__( self, model_max_tokens: int = 180_000, compact_size: float = 0.1, untouched_messages: int = 5, tool_specs: list[dict[str, Any]] | None = None, prompt_file_suffix: str = "system_prompt_v3.yaml", hf_token: str | None = None, local_mode: bool = False, ): self.system_prompt = self._load_system_prompt( tool_specs or [], prompt_file_suffix="system_prompt_v3.yaml", hf_token=hf_token, local_mode=local_mode, ) # The model's real input-token ceiling (from litellm.get_model_info). # Compaction triggers at _COMPACT_THRESHOLD_RATIO below it — see # the compaction_threshold property. self.model_max_tokens = model_max_tokens self.compact_size = int(model_max_tokens * compact_size) # Running count of tokens the last LLM call reported. Drives the # compaction gate; updated in add_message() with each response's # usage.total_tokens. self.running_context_usage = 0 self.untouched_messages = untouched_messages self.items: list[Message] = [Message(role="system", content=self.system_prompt)] def _load_system_prompt( self, tool_specs: list[dict[str, Any]], prompt_file_suffix: str = "system_prompt.yaml", hf_token: str | None = None, local_mode: bool = False, ): """Load and render the system prompt from YAML file with Jinja2""" prompt_file = Path(__file__).parent.parent / "prompts" / f"{prompt_file_suffix}" with open(prompt_file, "r") as f: prompt_data = yaml.safe_load(f) template_str = prompt_data.get("system_prompt", "") # Get current date and time tz = zoneinfo.ZoneInfo("Europe/Paris") now = datetime.now(tz) current_date = now.strftime("%d-%m-%Y") current_time = now.strftime("%H:%M:%S.%f")[:-3] current_timezone = f"{now.strftime('%Z')} (UTC{now.strftime('%z')[:3]}:{now.strftime('%z')[3:]})" # Get HF user info from OAuth token hf_user_info = _get_hf_username(hf_token) template = Template(template_str) static_prompt = template.render( tools=tool_specs, num_tools=len(tool_specs), ) # CLI-specific context for local mode if local_mode: import os cwd = os.getcwd() local_context = ( f"\n\n# CLI / Local mode\n\n" f"You are running as a local CLI tool on the user's machine. 
" f"There is NO sandbox — bash, read, write, and edit operate directly " f"on the local filesystem.\n\n" f"Working directory: {cwd}\n" f"Use absolute paths or paths relative to the working directory. " f"Do NOT use /app/ paths — that is a sandbox convention that does not apply here.\n" f"The sandbox_create tool is NOT available. Run code directly with bash." ) static_prompt += local_context return ( f"{static_prompt}\n\n" f"[Session context: Date={current_date}, Time={current_time}, " f"Timezone={current_timezone}, User={hf_user_info}, " f"Tools={len(tool_specs)}]" ) def add_message(self, message: Message, token_count: int = None) -> None: """Add a message to the history""" if token_count: self.running_context_usage = token_count self.items.append(message) def get_messages(self) -> list[Message]: """Get all messages for sending to LLM. Patches any dangling tool_calls (assistant messages with tool_calls that have no matching tool-result message) so the LLM API doesn't reject the request. """ self._patch_dangling_tool_calls() return self.items @staticmethod def _normalize_tool_calls(msg: Message) -> None: """Ensure msg.tool_calls contains proper ToolCall objects, not dicts. litellm's Message has validate_assignment=False (Pydantic v2 default), so direct attribute assignment (e.g. inside litellm's streaming handler) can leave raw dicts. Re-assigning via the constructor fixes this. """ from litellm import ChatCompletionMessageToolCall as ToolCall tool_calls = getattr(msg, "tool_calls", None) if not tool_calls: return needs_fix = any(isinstance(tc, dict) for tc in tool_calls) if not needs_fix: return msg.tool_calls = [ tc if not isinstance(tc, dict) else ToolCall(**tc) for tc in tool_calls ] def _patch_dangling_tool_calls(self) -> None: """Add stub tool results for any tool_calls that lack a matching result. Scans backwards to find the last assistant message with tool_calls, which may not be items[-1] if some tool results were already added. """ if not self.items: return # Find the last assistant message with tool_calls assistant_msg = None for i in range(len(self.items) - 1, -1, -1): msg = self.items[i] if getattr(msg, "role", None) == "assistant" and getattr( msg, "tool_calls", None ): assistant_msg = msg break # Stop scanning once we hit a user message — anything before # that belongs to a previous (complete) turn. if getattr(msg, "role", None) == "user": break if not assistant_msg: return self._normalize_tool_calls(assistant_msg) answered_ids = { getattr(m, "tool_call_id", None) for m in self.items if getattr(m, "role", None) == "tool" } for tc in assistant_msg.tool_calls: if tc.id not in answered_ids: self.items.append( Message( role="tool", content="Tool was not executed (interrupted or error).", tool_call_id=tc.id, name=tc.function.name, ) ) def undo_last_turn(self) -> bool: """Remove the last complete turn (user msg + all assistant/tool msgs that follow). Pops from the end until the last user message is removed, keeping the tool_use/tool_result pairing valid. Never removes the system message. Returns True if a user message was found and removed. """ if len(self.items) <= 1: return False while len(self.items) > 1: msg = self.items.pop() if getattr(msg, "role", None) == "user": return True return False def truncate_to_user_message(self, user_message_index: int) -> bool: """Truncate history to just before the Nth user message (0-indexed). Removes that user message and everything after it. System message (index 0) is never removed. 
Returns True if the target user message was found and removed. """ count = 0 for i, msg in enumerate(self.items): if i == 0: continue # skip system message if getattr(msg, "role", None) == "user": if count == user_message_index: self.items = self.items[:i] return True count += 1 return False # Compaction fires at 90% of model_max_tokens so there's headroom for # the next turn's prompt + response before we actually hit the ceiling. _COMPACT_THRESHOLD_RATIO = 0.9 @property def compaction_threshold(self) -> int: """Token count at which `compact()` kicks in.""" return int(self.model_max_tokens * self._COMPACT_THRESHOLD_RATIO) @property def needs_compaction(self) -> bool: return self.running_context_usage > self.compaction_threshold and bool(self.items) async def compact( self, model_name: str, tool_specs: list[dict] | None = None, hf_token: str | None = None, ) -> None: """Remove old messages to keep history under target size""" if not self.needs_compaction: return system_msg = ( self.items[0] if self.items and self.items[0].role == "system" else None ) # Preserve the first user message (task prompt) — never summarize it first_user_msg = None first_user_idx = 1 for i in range(1, len(self.items)): if getattr(self.items[i], "role", None) == "user": first_user_msg = self.items[i] first_user_idx = i break # Don't summarize a certain number of just-preceding messages # Walk back to find a user message to make sure we keep an assistant -> user -> # assistant general conversation structure idx = len(self.items) - self.untouched_messages while idx > 1 and self.items[idx].role != "user": idx -= 1 recent_messages = self.items[idx:] messages_to_summarize = self.items[first_user_idx + 1:idx] # improbable, messages would have to very long if not messages_to_summarize: return summary, completion_tokens = await summarize_messages( messages_to_summarize, model_name=model_name, hf_token=hf_token, max_tokens=self.compact_size, tool_specs=tool_specs, prompt=_COMPACT_PROMPT, ) summarized_message = Message(role="assistant", content=summary) # Reconstruct: system + first user msg + summary + recent messages head = [system_msg] if system_msg else [] if first_user_msg: head.append(first_user_msg) self.items = head + [summarized_message] + recent_messages # Count the actual post-compact context — system prompt + first user # turn + summary + the preserved tail all contribute, not just the # summary. litellm.token_counter uses the model's real tokenizer. 
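# For scale (illustrative, using the constructor defaults):
# model_max_tokens=180_000 gives compaction_threshold=162_000 (90%) and a
# summary budget (compact_size) of 18_000 tokens (10%).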
from litellm import token_counter try: self.running_context_usage = token_counter( model=model_name, messages=[m.model_dump() for m in self.items], ) except Exception as e: logger.warning("token_counter failed post-compact (%s); falling back to rough estimate", e) self.running_context_usage = len(self.system_prompt) // 4 + completion_tokens ================================================ FILE: agent/core/__init__.py ================================================ """ Core agent implementation Contains the main agent logic, decision-making, and orchestration """ from agent.core.tools import ToolRouter, ToolSpec, create_builtin_tools __all__ = [ "ToolRouter", "ToolSpec", "create_builtin_tools", ] ================================================ FILE: agent/core/agent_loop.py ================================================ """loop Main agent implementation with integrated tool system and MCP support """ import asyncio import json import logging import os from dataclasses import dataclass from litellm import ChatCompletionMessageToolCall, Message, acompletion from litellm.exceptions import ContextWindowExceededError from agent.config import Config from agent.core.doom_loop import check_for_doom_loop from agent.core.llm_params import _resolve_llm_params from agent.core.prompt_caching import with_prompt_caching from agent.core.session import Event, OpType, Session from agent.core.tools import ToolRouter from agent.tools.jobs_tool import CPU_FLAVORS logger = logging.getLogger(__name__) ToolCall = ChatCompletionMessageToolCall def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]: """ Validate tool arguments structure. Returns: (is_valid, error_message) """ args = tool_args.get("args", {}) # Sometimes LLM passes args as string instead of dict if isinstance(args, str): return ( False, f"Tool call error: 'args' must be a JSON object, not a string. You passed: {repr(args)}", ) if not isinstance(args, dict) and args is not None: return ( False, f"Tool call error: 'args' must be a JSON object. You passed type: {type(args).__name__}", ) return True, None def _needs_approval( tool_name: str, tool_args: dict, config: Config | None = None ) -> bool: """Check if a tool call requires user approval before execution.""" # Yolo mode: skip all approvals if config and config.yolo_mode: return False # If args are malformed, skip approval (validation error will be shown later) args_valid, _ = _validate_tool_args(tool_args) if not args_valid: return False if tool_name == "sandbox_create": return True if tool_name == "hf_jobs": operation = tool_args.get("operation", "") if operation not in ["run", "uv", "scheduled run", "scheduled uv"]: return False # Check if this is a CPU-only job # hardware_flavor is at top level of tool_args, not nested in args hardware_flavor = ( tool_args.get("hardware_flavor") or tool_args.get("flavor") or tool_args.get("hardware") or "cpu-basic" ) is_cpu_job = hardware_flavor in CPU_FLAVORS if is_cpu_job: if config and not config.confirm_cpu_jobs: return False return True return True # Check for file upload operations (hf_private_repos or other tools) if tool_name == "hf_private_repos": operation = tool_args.get("operation", "") if operation == "upload_file": if config and config.auto_file_upload: return False return True # Other operations (create_repo, etc.) 
always require approval if operation in ["create_repo"]: return True # hf_repo_files: upload (can overwrite) and delete require approval if tool_name == "hf_repo_files": operation = tool_args.get("operation", "") if operation in ["upload", "delete"]: return True # hf_repo_git: destructive operations require approval if tool_name == "hf_repo_git": operation = tool_args.get("operation", "") if operation in [ "delete_branch", "delete_tag", "merge_pr", "create_repo", "update_repo", ]: return True return False # -- LLM retry constants -------------------------------------------------- _MAX_LLM_RETRIES = 3 _LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries def _is_transient_error(error: Exception) -> bool: """Return True for errors that are likely transient and worth retrying.""" err_str = str(error).lower() transient_patterns = [ "timeout", "timed out", "429", "rate limit", "rate_limit", "503", "service unavailable", "502", "bad gateway", "500", "internal server error", "overloaded", "capacity", "connection reset", "connection refused", "connection error", "eof", "broken pipe", ] return any(pattern in err_str for pattern in transient_patterns) def _is_effort_config_error(error: Exception) -> bool: """Catch the two 400s the effort probe also handles — thinking unsupported for this model, or the specific effort level invalid. This is our safety net for the case where ``/effort`` was changed mid-conversation (which clears the probe cache) and the new level doesn't work for the current model. We heal the cache and retry once. """ from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported return _is_thinking_unsupported(error) or _is_invalid_effort(error) async def _heal_effort_and_rebuild_params( session: Session, error: Exception, llm_params: dict, ) -> dict: """Update the session's effort cache based on ``error`` and return new llm_params. Called only when ``_is_effort_config_error(error)`` is True. Two branches: • thinking-unsupported → cache ``None`` for this model, next call strips thinking entirely • invalid-effort → re-run the full cascade probe; the result lands in the cache """ from agent.core.effort_probe import ProbeInconclusive, _is_thinking_unsupported, probe_effort model = session.config.model_name if _is_thinking_unsupported(error): session.model_effective_effort[model] = None logger.info("healed: %s doesn't support thinking — stripped", model) else: try: outcome = await probe_effort( model, session.config.reasoning_effort, session.hf_token, ) session.model_effective_effort[model] = outcome.effective_effort logger.info( "healed: %s effort cascade → %s", model, outcome.effective_effort, ) except ProbeInconclusive: # Transient during healing — strip thinking for safety, next # call will either succeed or surface the real error. 
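# (For reference: session.model_effective_effort is the per-model cache
# mapping model id to the effort string the provider accepted, or None
# meaning "thinking stripped". effective_effort_for() reads this cache
# when building llm_params for each LLM call.)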
session.model_effective_effort[model] = None logger.info("healed: %s probe inconclusive — stripped", model) return _resolve_llm_params( model, session.hf_token, reasoning_effort=session.effective_effort_for(model), ) def _friendly_error_message(error: Exception) -> str | None: """Return a user-friendly message for known error types, or None to fall back to traceback.""" err_str = str(error).lower() if "authentication" in err_str or "unauthorized" in err_str or "invalid x-api-key" in err_str: return ( "Authentication failed — your API key is missing or invalid.\n\n" "To fix this, set the API key for your model provider:\n" " • Anthropic: export ANTHROPIC_API_KEY=sk-...\n" " • OpenAI: export OPENAI_API_KEY=sk-...\n" " • HF Router: export HF_TOKEN=hf_...\n\n" "You can also add it to a .env file in the project root.\n" "To switch models, use the /model command." ) if "insufficient" in err_str and "credit" in err_str: return ( "Insufficient API credits. Please check your account balance " "at your model provider's dashboard." ) if "not supported by provider" in err_str or "no provider supports" in err_str: return ( "The model isn't served by the provider you pinned.\n\n" "Drop the ':' suffix to let the HF router auto-pick a " "provider, or use '/model' (no arg) to see which providers host " "which models." ) if "model_not_found" in err_str or ( "model" in err_str and ("not found" in err_str or "does not exist" in err_str) ): return ( "Model not found. Use '/model' to list suggestions, or paste an " "HF model id like 'MiniMaxAI/MiniMax-M2.7'. Availability is shown " "when you switch." ) return None async def _compact_and_notify(session: Session) -> None: """Run compaction and send event if context was reduced.""" cm = session.context_manager old_usage = cm.running_context_usage logger.debug( "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s", old_usage, cm.model_max_tokens, cm.compaction_threshold, cm.needs_compaction, ) await cm.compact( model_name=session.config.model_name, tool_specs=session.tool_router.get_tool_specs_for_llm(), hf_token=session.hf_token, ) new_usage = cm.running_context_usage if new_usage != old_usage: logger.warning( "Context compacted: %d -> %d tokens (max=%d, %d messages)", old_usage, new_usage, cm.model_max_tokens, len(cm.items), ) await session.send_event( Event( event_type="compacted", data={"old_tokens": old_usage, "new_tokens": new_usage}, ) ) async def _cleanup_on_cancel(session: Session) -> None: """Kill sandbox processes and cancel HF jobs when the user interrupts.""" # Kill active sandbox processes sandbox = getattr(session, "sandbox", None) if sandbox: try: await asyncio.to_thread(sandbox.kill_all) logger.info("Killed sandbox processes on cancel") except Exception as e: logger.warning("Failed to kill sandbox processes: %s", e) # Cancel running HF jobs job_ids = list(session._running_job_ids) if job_ids: from huggingface_hub import HfApi api = HfApi(token=session.hf_token) for job_id in job_ids: try: await asyncio.to_thread(api.cancel_job, job_id=job_id) logger.info("Cancelled HF job %s on interrupt", job_id) except Exception as e: logger.warning("Failed to cancel HF job %s: %s", job_id, e) session._running_job_ids.clear() @dataclass class LLMResult: """Result from an LLM call (streaming or non-streaming).""" content: str | None tool_calls_acc: dict[int, dict] token_count: int finish_reason: str | None async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult: """Call the LLM with streaming, emitting 
assistant_chunk events.""" response = None _healed_effort = False # one-shot safety net per call messages, tools = with_prompt_caching(messages, tools, llm_params.get("model")) for _llm_attempt in range(_MAX_LLM_RETRIES): try: response = await acompletion( messages=messages, tools=tools, tool_choice="auto", stream=True, stream_options={"include_usage": True}, timeout=600, **llm_params, ) break except ContextWindowExceededError: raise except Exception as e: if not _healed_effort and _is_effort_config_error(e): _healed_effort = True llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params) await session.send_event(Event( event_type="tool_log", data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, )) continue if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( "Transient LLM error (attempt %d/%d): %s — retrying in %ds", _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, ) await session.send_event(Event( event_type="tool_log", data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, )) await asyncio.sleep(_delay) continue raise full_content = "" tool_calls_acc: dict[int, dict] = {} token_count = 0 finish_reason = None async for chunk in response: if session.is_cancelled: tool_calls_acc.clear() break choice = chunk.choices[0] if chunk.choices else None if not choice: if hasattr(chunk, "usage") and chunk.usage: token_count = chunk.usage.total_tokens continue delta = choice.delta if choice.finish_reason: finish_reason = choice.finish_reason if delta.content: full_content += delta.content await session.send_event( Event(event_type="assistant_chunk", data={"content": delta.content}) ) if delta.tool_calls: for tc_delta in delta.tool_calls: idx = tc_delta.index if idx not in tool_calls_acc: tool_calls_acc[idx] = { "id": "", "type": "function", "function": {"name": "", "arguments": ""}, } if tc_delta.id: tool_calls_acc[idx]["id"] = tc_delta.id if tc_delta.function: if tc_delta.function.name: tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name if tc_delta.function.arguments: tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments if hasattr(chunk, "usage") and chunk.usage: token_count = chunk.usage.total_tokens return LLMResult( content=full_content or None, tool_calls_acc=tool_calls_acc, token_count=token_count, finish_reason=finish_reason, ) async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) -> LLMResult: """Call the LLM without streaming, emit assistant_message at the end.""" response = None _healed_effort = False messages, tools = with_prompt_caching(messages, tools, llm_params.get("model")) for _llm_attempt in range(_MAX_LLM_RETRIES): try: response = await acompletion( messages=messages, tools=tools, tool_choice="auto", stream=False, timeout=600, **llm_params, ) break except ContextWindowExceededError: raise except Exception as e: if not _healed_effort and _is_effort_config_error(e): _healed_effort = True llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params) await session.send_event(Event( event_type="tool_log", data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, )) continue if _llm_attempt < _MAX_LLM_RETRIES - 1 and _is_transient_error(e): _delay = _LLM_RETRY_DELAYS[_llm_attempt] logger.warning( "Transient LLM error (attempt %d/%d): %s — retrying in %ds", _llm_attempt + 1, _MAX_LLM_RETRIES, e, 
_delay, ) await session.send_event(Event( event_type="tool_log", data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, )) await asyncio.sleep(_delay) continue raise choice = response.choices[0] message = choice.message content = message.content or None finish_reason = choice.finish_reason token_count = response.usage.total_tokens if response.usage else 0 # Build tool_calls_acc in the same format as streaming tool_calls_acc: dict[int, dict] = {} if message.tool_calls: for idx, tc in enumerate(message.tool_calls): tool_calls_acc[idx] = { "id": tc.id, "type": "function", "function": { "name": tc.function.name, "arguments": tc.function.arguments, }, } # Emit the full message as a single event if content: await session.send_event( Event(event_type="assistant_message", data={"content": content}) ) return LLMResult( content=content, tool_calls_acc=tool_calls_acc, token_count=token_count, finish_reason=finish_reason, ) class Handlers: """Handler functions for each operation type""" @staticmethod async def _abandon_pending_approval(session: Session) -> None: """Cancel pending approval tools when the user continues the conversation. Injects rejection tool-result messages into the LLM context (so the history stays valid) and notifies the frontend that those tools were abandoned. """ tool_calls = session.pending_approval.get("tool_calls", []) for tc in tool_calls: tool_name = tc.function.name abandon_msg = ( "Task abandoned — user continued the conversation without approving." ) # Keep LLM context valid: every tool_call needs a tool result tool_msg = Message( role="tool", content=abandon_msg, tool_call_id=tc.id, name=tool_name, ) session.context_manager.add_message(tool_msg) await session.send_event( Event( event_type="tool_state_change", data={ "tool_call_id": tc.id, "tool": tool_name, "state": "abandoned", }, ) ) session.pending_approval = None logger.info("Abandoned %d pending approval tool(s)", len(tool_calls)) @staticmethod async def run_agent( session: Session, text: str, ) -> str | None: """ Handle user input (like user_input_or_turn in codex.rs:1291) Returns the final assistant response content, if any. """ # Clear any stale cancellation flag from a previous run session.reset_cancel() # If there's a pending approval and the user sent a new message, # abandon the pending tools so the LLM context stays valid. 
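# (Invariant: every assistant tool_call must be answered by a role="tool"
# message before the next user message, otherwise the provider rejects
# the request; _abandon_pending_approval injects "Task abandoned" stubs
# to keep the history valid.)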
if text and session.pending_approval: await Handlers._abandon_pending_approval(session) # Add user message to history only if there's actual content if text: user_msg = Message(role="user", content=text) session.context_manager.add_message(user_msg) # Send event that we're processing await session.send_event( Event(event_type="processing", data={"message": "Processing user input"}) ) # Agentic loop - continue until model doesn't call tools or max iterations is reached iteration = 0 final_response = None errored = False max_iterations = session.config.max_iterations while max_iterations == -1 or iteration < max_iterations: # ── Cancellation check: before LLM call ── if session.is_cancelled: break # Compact before calling the LLM if context is near the limit await _compact_and_notify(session) # Doom-loop detection: break out of repeated tool call patterns doom_prompt = check_for_doom_loop(session.context_manager.items) if doom_prompt: session.context_manager.add_message( Message(role="user", content=doom_prompt) ) await session.send_event( Event( event_type="tool_log", data={ "tool": "system", "log": "Doom loop detected — injecting corrective prompt", }, ) ) messages = session.context_manager.get_messages() tools = session.tool_router.get_tool_specs_for_llm() try: # ── Call the LLM (streaming or non-streaming) ── # Pull the per-model probed effort from the session cache when # available; fall back to the raw preference for models we # haven't probed yet (e.g. research sub-model). llm_params = _resolve_llm_params( session.config.model_name, session.hf_token, reasoning_effort=session.effective_effort_for(session.config.model_name), ) if session.stream: llm_result = await _call_llm_streaming(session, messages, tools, llm_params) else: llm_result = await _call_llm_non_streaming(session, messages, tools, llm_params) content = llm_result.content tool_calls_acc = llm_result.tool_calls_acc token_count = llm_result.token_count finish_reason = llm_result.finish_reason # If output was truncated, all tool call args are garbage. # Inject a system hint so the LLM retries with smaller content. if finish_reason == "length" and tool_calls_acc: dropped_names = [ tc["function"]["name"] for tc in tool_calls_acc.values() if tc["function"]["name"] ] logger.warning( "Output truncated (finish_reason=length) — dropping tool calls: %s", dropped_names, ) tool_calls_acc.clear() # Tell the agent what happened so it can retry differently truncation_hint = ( "Your previous response was truncated because the output hit the " "token limit. The following tool calls were lost: " f"{dropped_names}. " "IMPORTANT: Do NOT retry with the same large content. Instead:\n" " • For 'write': use bash with cat<<'HEREDOC' to write the file, " "or split into several smaller edit calls.\n" " • For other tools: reduce the size of your arguments or use bash." 
) if content: assistant_msg = Message(role="assistant", content=content) session.context_manager.add_message(assistant_msg, token_count) session.context_manager.add_message( Message(role="user", content=f"[SYSTEM: {truncation_hint}]") ) if session.stream: await session.send_event( Event(event_type="assistant_stream_end", data={}) ) await session.send_event( Event( event_type="tool_log", data={"tool": "system", "log": f"Output truncated — retrying with smaller content ({dropped_names})"}, ) ) iteration += 1 continue # retry this iteration # Build tool_calls list from accumulated deltas tool_calls: list[ToolCall] = [] for idx in sorted(tool_calls_acc.keys()): tc_data = tool_calls_acc[idx] tool_calls.append( ToolCall( id=tc_data["id"], type="function", function={ "name": tc_data["function"]["name"], "arguments": tc_data["function"]["arguments"], }, ) ) # Signal end of streaming to the frontend if session.stream: await session.send_event( Event(event_type="assistant_stream_end", data={}) ) # If no tool calls, add assistant message and we're done if not tool_calls: logger.debug( "Agent loop ending: no tool calls. " "finish_reason=%s, token_count=%d, " "usage=%d, model_max_tokens=%d, " "iteration=%d/%d, " "response_text=%s", finish_reason, token_count, session.context_manager.running_context_usage, session.context_manager.model_max_tokens, iteration, max_iterations, (content or "")[:500], ) if content: assistant_msg = Message(role="assistant", content=content) session.context_manager.add_message(assistant_msg, token_count) final_response = content break # Validate tool call args (one json.loads per call, once) # and split into good vs bad good_tools: list[tuple[ToolCall, str, dict]] = [] bad_tools: list[ToolCall] = [] for tc in tool_calls: try: args = json.loads(tc.function.arguments) good_tools.append((tc, tc.function.name, args)) except (json.JSONDecodeError, TypeError, ValueError): logger.warning( "Malformed arguments for tool_call %s (%s) — skipping", tc.id, tc.function.name, ) tc.function.arguments = "{}" bad_tools.append(tc) # Add assistant message with all tool calls to context assistant_msg = Message( role="assistant", content=content, tool_calls=tool_calls, ) session.context_manager.add_message(assistant_msg, token_count) # Add error results for bad tool calls so the LLM # knows what happened and can retry differently for tc in bad_tools: error_msg = ( f"ERROR: Tool call to '{tc.function.name}' had malformed JSON " f"arguments and was NOT executed. Retry with smaller content — " f"for 'write', split into multiple smaller writes using 'edit'." 
) session.context_manager.add_message(Message( role="tool", content=error_msg, tool_call_id=tc.id, name=tc.function.name, )) await session.send_event(Event( event_type="tool_call", data={"tool": tc.function.name, "arguments": {}, "tool_call_id": tc.id}, )) await session.send_event(Event( event_type="tool_output", data={"tool": tc.function.name, "tool_call_id": tc.id, "output": error_msg, "success": False}, )) # ── Cancellation check: before tool execution ── if session.is_cancelled: break # Separate good tools into approval-required vs auto-execute approval_required_tools: list[tuple[ToolCall, str, dict]] = [] non_approval_tools: list[tuple[ToolCall, str, dict]] = [] for tc, tool_name, tool_args in good_tools: if _needs_approval(tool_name, tool_args, session.config): approval_required_tools.append((tc, tool_name, tool_args)) else: non_approval_tools.append((tc, tool_name, tool_args)) # Execute non-approval tools (in parallel when possible) if non_approval_tools: # 1. Validate args upfront parsed_tools: list[ tuple[ToolCall, str, dict, bool, str] ] = [] for tc, tool_name, tool_args in non_approval_tools: args_valid, error_msg = _validate_tool_args(tool_args) parsed_tools.append( (tc, tool_name, tool_args, args_valid, error_msg) ) # 2. Send all tool_call events upfront (so frontend shows them all) for tc, tool_name, tool_args, args_valid, _ in parsed_tools: if args_valid: await session.send_event( Event( event_type="tool_call", data={ "tool": tool_name, "arguments": tool_args, "tool_call_id": tc.id, }, ) ) # 3. Execute all valid tools in parallel, cancellable async def _exec_tool( tc: ToolCall, name: str, args: dict, valid: bool, err: str, ) -> tuple[ToolCall, str, dict, str, bool]: if not valid: return (tc, name, args, err, False) out, ok = await session.tool_router.call_tool( name, args, session=session, tool_call_id=tc.id ) return (tc, name, args, out, ok) gather_task = asyncio.ensure_future(asyncio.gather( *[ _exec_tool(tc, name, args, valid, err) for tc, name, args, valid, err in parsed_tools ] )) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( [gather_task, cancel_task], return_when=asyncio.FIRST_COMPLETED, ) if cancel_task in done: gather_task.cancel() try: await gather_task except asyncio.CancelledError: pass # Notify frontend that in-flight tools were cancelled for tc, name, _args, valid, _ in parsed_tools: if valid: await session.send_event(Event( event_type="tool_state_change", data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"}, )) await _cleanup_on_cancel(session) break cancel_task.cancel() results = gather_task.result() # 4. Record results and send outputs (order preserved) for tc, tool_name, tool_args, output, success in results: tool_msg = Message( role="tool", content=output, tool_call_id=tc.id, name=tool_name, ) session.context_manager.add_message(tool_msg) await session.send_event( Event( event_type="tool_output", data={ "tool": tool_name, "tool_call_id": tc.id, "output": output, "success": success, }, ) ) # If there are tools requiring approval, ask for batch approval if approval_required_tools: # Prepare batch approval data tools_data = [] for tc, tool_name, tool_args in approval_required_tools: # Resolve sandbox file paths for hf_jobs scripts so the # frontend can display & edit the actual file content. 
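# (Illustrative, inferred from the helper's name: when the "script" argument
# references a file inside the sandbox, it is swapped for that file's actual
# content here, so the approval dialog shows real code rather than a path.
# The exact resolution rules live in agent.tools.sandbox_tool.resolve_sandbox_script.)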
if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str): from agent.tools.sandbox_tool import resolve_sandbox_script sandbox = getattr(session, "sandbox", None) resolved, _ = await resolve_sandbox_script(sandbox, tool_args["script"]) if resolved: tool_args = {**tool_args, "script": resolved} tools_data.append({ "tool": tool_name, "arguments": tool_args, "tool_call_id": tc.id, }) await session.send_event(Event( event_type="approval_required", data={"tools": tools_data, "count": len(tools_data)}, )) # Store all approval-requiring tools (ToolCall objects for execution) session.pending_approval = { "tool_calls": [tc for tc, _, _ in approval_required_tools], } # Return early - wait for EXEC_APPROVAL operation return None iteration += 1 except ContextWindowExceededError: # Force compact and retry this iteration cm = session.context_manager logger.warning( "ContextWindowExceededError at iteration %d — forcing compaction " "(usage=%d, model_max_tokens=%d, messages=%d)", iteration, cm.running_context_usage, cm.model_max_tokens, len(cm.items), ) cm.running_context_usage = cm.model_max_tokens + 1 await _compact_and_notify(session) continue except Exception as e: import traceback error_msg = _friendly_error_message(e) if error_msg is None: error_msg = str(e) + "\n" + traceback.format_exc() await session.send_event( Event( event_type="error", data={"error": error_msg}, ) ) errored = True break if session.is_cancelled: await _cleanup_on_cancel(session) await session.send_event(Event(event_type="interrupted")) elif not errored: await session.send_event( Event( event_type="turn_complete", data={"history_size": len(session.context_manager.items)}, ) ) # Increment turn counter and check for auto-save session.increment_turn() await session.auto_save_if_needed() return final_response @staticmethod async def undo(session: Session) -> None: """Remove the last complete turn and notify the frontend.""" removed = session.context_manager.undo_last_turn() if not removed: logger.warning("Undo: no user message found to remove") await session.send_event(Event(event_type="undo_complete")) @staticmethod async def exec_approval(session: Session, approvals: list[dict]) -> None: """Handle batch job execution approval""" if not session.pending_approval: await session.send_event( Event( event_type="error", data={"error": "No pending approval to process"}, ) ) return tool_calls = session.pending_approval.get("tool_calls", []) if not tool_calls: await session.send_event( Event( event_type="error", data={"error": "No pending tool calls found"}, ) ) return # Create a map of tool_call_id -> approval decision approval_map = {a["tool_call_id"]: a for a in approvals} for a in approvals: if a.get("edited_script"): logger.info( f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)" ) # Separate approved and rejected tool calls approved_tasks = [] rejected_tasks = [] for tc in tool_calls: tool_name = tc.function.name try: tool_args = json.loads(tc.function.arguments) except (json.JSONDecodeError, TypeError) as e: # Malformed arguments — treat as failed, notify agent logger.warning(f"Malformed tool arguments for {tool_name}: {e}") tool_msg = Message( role="tool", content=f"Malformed arguments: {e}", tool_call_id=tc.id, name=tool_name, ) session.context_manager.add_message(tool_msg) await session.send_event( Event( event_type="tool_output", data={ "tool": tool_name, "tool_call_id": tc.id, "output": f"Malformed arguments: {e}", "success": False, }, ) ) continue approval_decision = 
approval_map.get(tc.id, {"approved": False}) if approval_decision.get("approved", False): edited_script = approval_decision.get("edited_script") was_edited = False if edited_script and "script" in tool_args: tool_args["script"] = edited_script was_edited = True logger.info(f"Using user-edited script for {tool_name} ({tc.id})") approved_tasks.append((tc, tool_name, tool_args, was_edited)) else: rejected_tasks.append((tc, tool_name, approval_decision)) # Clear pending approval immediately so a page refresh during # execution won't re-show the approval dialog. session.pending_approval = None # Notify frontend of approval decisions immediately (before execution) for tc, tool_name, tool_args, _was_edited in approved_tasks: await session.send_event( Event( event_type="tool_state_change", data={ "tool_call_id": tc.id, "tool": tool_name, "state": "approved", }, ) ) for tc, tool_name, approval_decision in rejected_tasks: await session.send_event( Event( event_type="tool_state_change", data={ "tool_call_id": tc.id, "tool": tool_name, "state": "rejected", }, ) ) # Execute all approved tools concurrently async def execute_tool(tc, tool_name, tool_args, was_edited): """Execute a single tool and return its result. The TraceLog already exists on the frontend (created by approval_required), so we send tool_state_change instead of tool_call to avoid creating a duplicate. """ await session.send_event( Event( event_type="tool_state_change", data={ "tool_call_id": tc.id, "tool": tool_name, "state": "running", }, ) ) output, success = await session.tool_router.call_tool( tool_name, tool_args, session=session, tool_call_id=tc.id ) return (tc, tool_name, output, success, was_edited) # Execute all approved tools concurrently (cancellable) if approved_tasks: gather_task = asyncio.ensure_future(asyncio.gather( *[ execute_tool(tc, tool_name, tool_args, was_edited) for tc, tool_name, tool_args, was_edited in approved_tasks ], return_exceptions=True, )) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( [gather_task, cancel_task], return_when=asyncio.FIRST_COMPLETED, ) if cancel_task in done: gather_task.cancel() try: await gather_task except asyncio.CancelledError: pass # Notify frontend that approved tools were cancelled for tc, tool_name, _args, _was_edited in approved_tasks: await session.send_event(Event( event_type="tool_state_change", data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"}, )) await _cleanup_on_cancel(session) await session.send_event(Event(event_type="interrupted")) session.increment_turn() await session.auto_save_if_needed() return cancel_task.cancel() results = gather_task.result() # Process results and add to context for result in results: if isinstance(result, Exception): # Handle execution error logger.error(f"Tool execution error: {result}") continue tc, tool_name, output, success, was_edited = result if was_edited: output = f"[Note: The user edited the script before execution. 
The output below reflects the user-modified version, not your original script.]\n\n{output}" # Add tool result to context tool_msg = Message( role="tool", content=output, tool_call_id=tc.id, name=tool_name, ) session.context_manager.add_message(tool_msg) await session.send_event( Event( event_type="tool_output", data={ "tool": tool_name, "tool_call_id": tc.id, "output": output, "success": success, }, ) ) # Process rejected tools for tc, tool_name, approval_decision in rejected_tasks: rejection_msg = "Job execution cancelled by user" user_feedback = approval_decision.get("feedback") if user_feedback: # Ensure feedback is a string and sanitize any problematic characters feedback_str = str(user_feedback).strip() # Remove any control characters that might break JSON parsing feedback_str = "".join( char for char in feedback_str if ord(char) >= 32 or char in "\n\t" ) rejection_msg += f". User feedback: {feedback_str}" # Ensure rejection_msg is a clean string rejection_msg = str(rejection_msg).strip() tool_msg = Message( role="tool", content=rejection_msg, tool_call_id=tc.id, name=tool_name, ) session.context_manager.add_message(tool_msg) await session.send_event( Event( event_type="tool_output", data={ "tool": tool_name, "tool_call_id": tc.id, "output": rejection_msg, "success": False, }, ) ) # Continue agent loop with empty input to process the tool results await Handlers.run_agent(session, "") @staticmethod async def shutdown(session: Session) -> bool: """Handle shutdown (like shutdown in codex.rs:1329)""" # Save session trajectory if enabled (fire-and-forget, returns immediately) if session.config.save_sessions: logger.info("Saving session...") repo_id = session.config.session_dataset_repo _ = session.save_and_upload_detached(repo_id) session.is_running = False await session.send_event(Event(event_type="shutdown")) return True async def process_submission(session: Session, submission) -> bool: """ Process a single submission and return whether to continue running. Returns: bool: True to continue, False to shutdown """ op = submission.operation logger.debug("Received operation: %s", op.op_type.value) if op.op_type == OpType.USER_INPUT: text = op.data.get("text", "") if op.data else "" await Handlers.run_agent(session, text) return True if op.op_type == OpType.COMPACT: await _compact_and_notify(session) return True if op.op_type == OpType.UNDO: await Handlers.undo(session) return True if op.op_type == OpType.EXEC_APPROVAL: approvals = op.data.get("approvals", []) if op.data else [] await Handlers.exec_approval(session, approvals) return True if op.op_type == OpType.SHUTDOWN: return not await Handlers.shutdown(session) logger.warning(f"Unknown operation: {op.op_type}") return True async def submission_loop( submission_queue: asyncio.Queue, event_queue: asyncio.Queue, config: Config | None = None, tool_router: ToolRouter | None = None, session_holder: list | None = None, hf_token: str | None = None, local_mode: bool = False, stream: bool = True, ) -> None: """ Main agent loop - processes submissions and dispatches to handlers. 
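    A minimal embedding sketch (queue wiring only; variable names are
    illustrative, not part of the API):

        submissions: asyncio.Queue = asyncio.Queue()
        events: asyncio.Queue = asyncio.Queue()
        loop_task = asyncio.create_task(
            submission_loop(submissions, events, config=config, tool_router=router)
        )
        # push submissions into `submissions`, consume Events from `events`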
This is the core of the agent (like submission_loop in codex.rs:1259-1340) """ # Create session with tool router session = Session( event_queue, config=config, tool_router=tool_router, hf_token=hf_token, local_mode=local_mode, stream=stream, ) if session_holder is not None: session_holder[0] = session logger.info("Agent loop started") # Retry any failed uploads from previous sessions (fire-and-forget) if config and config.save_sessions: Session.retry_failed_uploads_detached( directory="session_logs", repo_id=config.session_dataset_repo ) try: # Main processing loop async with tool_router: # Emit ready event after initialization await session.send_event( Event(event_type="ready", data={ "message": "Agent initialized", "tool_count": len(tool_router.tools), }) ) while session.is_running: submission = await submission_queue.get() try: should_continue = await process_submission(session, submission) if not should_continue: break except asyncio.CancelledError: logger.warning("Agent loop cancelled") break except Exception as e: logger.error(f"Error in agent loop: {e}") await session.send_event( Event(event_type="error", data={"error": str(e)}) ) logger.info("Agent loop exited") finally: # Emergency save if session saving is enabled and shutdown wasn't called properly if session.config.save_sessions and session.is_running: logger.info("Emergency save: preserving session before exit...") try: local_path = session.save_and_upload_detached( session.config.session_dataset_repo ) if local_path: logger.info("Emergency save successful, upload in progress") except Exception as e: logger.error(f"Emergency save failed: {e}") ================================================ FILE: agent/core/doom_loop.py ================================================ """ Doom-loop detection for repeated tool call patterns. Detects when the agent is stuck calling the same tools repeatedly and injects a corrective prompt to break the cycle. 
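For instance (hypothetical signatures), three identical back-to-back calls
trip the consecutive detector:

    >>> sigs = [ToolCallSignature(name="read", args_hash="abc123")] * 3
    >>> detect_identical_consecutive(sigs, threshold=3)
    'read'

while an alternating tail such as [A, B, A, B] trips the sequence detector instead.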
""" import hashlib import json import logging from dataclasses import dataclass from litellm import Message logger = logging.getLogger(__name__) @dataclass(frozen=True) class ToolCallSignature: """Hashable signature for a single tool call (name + args hash).""" name: str args_hash: str def _hash_args(args_str: str) -> str: """Return a short hash of the JSON arguments string.""" return hashlib.md5(args_str.encode()).hexdigest()[:12] def extract_recent_tool_signatures( messages: list[Message], lookback: int = 30 ) -> list[ToolCallSignature]: """Extract tool call signatures from recent assistant messages.""" signatures: list[ToolCallSignature] = [] recent = messages[-lookback:] if len(messages) > lookback else messages for msg in recent: if getattr(msg, "role", None) != "assistant": continue tool_calls = getattr(msg, "tool_calls", None) if not tool_calls: continue for tc in tool_calls: fn = getattr(tc, "function", None) if not fn: continue name = getattr(fn, "name", "") or "" args_str = getattr(fn, "arguments", "") or "" signatures.append(ToolCallSignature(name=name, args_hash=_hash_args(args_str))) return signatures def detect_identical_consecutive( signatures: list[ToolCallSignature], threshold: int = 3 ) -> str | None: """Return the tool name if threshold+ identical consecutive calls are found.""" if len(signatures) < threshold: return None count = 1 for i in range(1, len(signatures)): if signatures[i] == signatures[i - 1]: count += 1 if count >= threshold: return signatures[i].name else: count = 1 return None def detect_repeating_sequence( signatures: list[ToolCallSignature], ) -> list[ToolCallSignature] | None: """Detect repeating patterns like [A,B,A,B] for sequences of length 2-5 with 2+ reps.""" n = len(signatures) for seq_len in range(2, 6): min_required = seq_len * 2 if n < min_required: continue # Check the tail of the signatures list tail = signatures[-min_required:] pattern = tail[:seq_len] # Count how many full repetitions from the end reps = 0 for start in range(n - seq_len, -1, -seq_len): chunk = signatures[start : start + seq_len] if chunk == pattern: reps += 1 else: break if reps >= 2: return pattern return None def check_for_doom_loop(messages: list[Message]) -> str | None: """Check for doom loop patterns. Returns a corrective prompt or None.""" signatures = extract_recent_tool_signatures(messages, lookback=30) if len(signatures) < 3: return None # Check for identical consecutive calls tool_name = detect_identical_consecutive(signatures, threshold=3) if tool_name: logger.warning("Doom loop detected: %d+ identical consecutive calls to '%s'", 3, tool_name) return ( f"[SYSTEM: DOOM LOOP DETECTED] You have called '{tool_name}' with the same " f"arguments multiple times in a row, getting the same result each time. " f"STOP repeating this approach — it is not working. " f"Step back and try a fundamentally different strategy. " f"Consider: using a different tool, changing your arguments significantly, " f"or explaining to the user what you're stuck on and asking for guidance." ) # Check for repeating sequences pattern = detect_repeating_sequence(signatures) if pattern: pattern_desc = " → ".join(s.name for s in pattern) logger.warning("Doom loop detected: repeating sequence [%s]", pattern_desc) return ( f"[SYSTEM: DOOM LOOP DETECTED] You are stuck in a repeating cycle of tool calls: " f"[{pattern_desc}]. This pattern has repeated multiple times without progress. " f"STOP this cycle and try a fundamentally different approach. 
" f"Consider: breaking down the problem differently, using alternative tools, " f"or explaining to the user what you're stuck on and asking for guidance." ) return None ================================================ FILE: agent/core/effort_probe.py ================================================ """Probe-and-cascade for reasoning effort on /model switch. We don't maintain a per-model capability table. Instead, the first time a user picks a model we fire a 1-token ping with the same params we'd use for real and walk down a cascade (``max`` → ``xhigh`` → ``high`` → …) until the provider stops rejecting us. The result is cached per-model on the session, so real messages don't pay the probe cost again. Three outcomes, classified from the 400 error text: * success → cache the effort that worked * ``"thinking ... not supported"`` → model doesn't do thinking at all; cache ``None`` so we stop sending thinking params * ``"effort ... invalid"`` / synonyms → cascade walks down and retries Transient errors (5xx, timeout, connection reset) bubble out as ``ProbeInconclusive`` so the caller can complete the switch with a warning instead of blocking on a flaky provider. """ from __future__ import annotations import asyncio import logging from dataclasses import dataclass from litellm import acompletion from agent.core.llm_params import UnsupportedEffortError, _resolve_llm_params logger = logging.getLogger(__name__) # Cascade: for each user-stated preference, the ordered list of levels to # try. First success wins. ``max`` / ``xhigh`` are Anthropic-only; providers # that don't accept them raise ``UnsupportedEffortError`` synchronously (no # wasted network round-trip) and we advance to the next level. _EFFORT_CASCADE: dict[str, list[str]] = { "max": ["max", "xhigh", "high", "medium", "low"], "xhigh": ["xhigh", "high", "medium", "low"], "high": ["high", "medium", "low"], "medium": ["medium", "low"], "minimal": ["minimal", "low"], "low": ["low"], } _PROBE_TIMEOUT = 15.0 _PROBE_MAX_TOKENS = 16 class ProbeInconclusive(Exception): """The probe couldn't reach a verdict (transient network / provider error). Caller should complete the switch with a warning — the next real call will re-surface the error if it's persistent. """ @dataclass class ProbeOutcome: """What the probe learned. ``effective_effort`` semantics match the cache: * str → send this level * None → model doesn't support thinking; strip it """ effective_effort: str | None attempts: int elapsed_ms: int note: str | None = None # e.g. "max not supported, falling back" def _is_thinking_unsupported(e: Exception) -> bool: """Model rejected any thinking config. Matches Anthropic's 'thinking.type.enabled is not supported for this model' as well as the adaptive variant. Substring-match because the exact wording shifts across API versions. """ s = str(e).lower() return "thinking" in s and "not supported" in s def _is_invalid_effort(e: Exception) -> bool: """The requested effort level isn't accepted for this model. Covers both API responses (Anthropic/OpenAI 400 with "invalid", "must be one of", etc.) and LiteLLM's local validation that fires *before* the request (e.g. "effort='max' is only supported by Claude Opus 4.6" — LiteLLM knows max is Opus-4.6-only and raises synchronously). The cascade walks down on either. Explicitly returns False when the message is really about thinking itself (e.g. Anthropic's 4.7 error mentions ``output_config.effort`` in its fix hint, but the actual failure is ``thinking.type.enabled`` being unsupported). 
That case is caught by ``_is_thinking_unsupported``. """ if _is_thinking_unsupported(e): return False s = str(e).lower() if "effort" not in s and "output_config" not in s: return False return any( phrase in s for phrase in ( "invalid", "not supported", "must be one of", "not a valid", "unrecognized", "unknown", # LiteLLM's own pre-flight validation phrasing. "only supported by", "is only supported", ) ) def _is_transient(e: Exception) -> bool: """Network / provider-side flake. Keep in sync with agent_loop's list. Also matches by type for ``asyncio.TimeoutError`` — its ``str(e)`` is empty, so substring matching alone misses it. """ if isinstance(e, (asyncio.TimeoutError, TimeoutError)): return True s = str(e).lower() return any( p in s for p in ( "timeout", "timed out", "429", "rate limit", "503", "service unavailable", "502", "bad gateway", "500", "internal server error", "overloaded", "capacity", "connection reset", "connection refused", "connection error", "eof", "broken pipe", ) ) async def probe_effort( model_name: str, preference: str | None, hf_token: str | None, ) -> ProbeOutcome: """Walk the cascade for ``preference`` on ``model_name``. Returns the first effort the provider accepts, or ``None`` if it rejects thinking altogether. Raises ``ProbeInconclusive`` only for transient errors (5xx, timeout) — persistent 4xx that aren't thinking/ effort related bubble as the original exception so callers can surface them (auth, model-not-found, quota, etc.). """ loop = asyncio.get_event_loop() start = loop.time() attempts = 0 if not preference: # User explicitly turned effort off — nothing to probe. A bare # ping with no thinking params is pointless; just report "off". return ProbeOutcome(effective_effort=None, attempts=0, elapsed_ms=0) cascade = _EFFORT_CASCADE.get(preference, [preference]) skipped: list[str] = [] # levels the provider rejected synchronously last_error: Exception | None = None for effort in cascade: try: params = _resolve_llm_params( model_name, hf_token, reasoning_effort=effort, strict=True, ) except UnsupportedEffortError: # Provider can't even accept this effort name (e.g. "max" on # HF router). Skip without a network call. skipped.append(effort) continue attempts += 1 try: await asyncio.wait_for( acompletion( messages=[{"role": "user", "content": "ping"}], max_tokens=_PROBE_MAX_TOKENS, stream=False, **params, ), timeout=_PROBE_TIMEOUT, ) except Exception as e: last_error = e if _is_thinking_unsupported(e): elapsed = int((loop.time() - start) * 1000) return ProbeOutcome( effective_effort=None, attempts=attempts, elapsed_ms=elapsed, note="model doesn't support reasoning, dropped", ) if _is_invalid_effort(e): logger.debug("probe: %s rejected effort=%s, trying next", model_name, effort) continue if _is_transient(e): raise ProbeInconclusive(str(e)) from e # Persistent non-thinking 4xx (auth, quota, model-not-found) — # let the caller classify & surface. raise else: elapsed = int((loop.time() - start) * 1000) note = None if effort != preference: note = f"{preference} not supported, using {effort}" return ProbeOutcome( effective_effort=effort, attempts=attempts, elapsed_ms=elapsed, note=note, ) # Cascade exhausted without a success. This only happens when every # level was either rejected synchronously (``UnsupportedEffortError``, # e.g. preference=max on HF and we also somehow filtered all others) # or the provider 400'd ``invalid effort`` on every level. 
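        # (Every cascade in _EFFORT_CASCADE ends in "low", which all three
        # provider effort sets accept, so reaching this point should be rare.)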
elapsed = int((loop.time() - start) * 1000) if last_error is not None and not _is_invalid_effort(last_error): raise last_error note = ( "no effort level accepted — proceeding without thinking" if not skipped else f"provider rejected all efforts ({', '.join(skipped)})" ) return ProbeOutcome( effective_effort=None, attempts=attempts, elapsed_ms=elapsed, note=note, ) ================================================ FILE: agent/core/hf_router_catalog.py ================================================ """Fetch and cache the HF Inference Router model catalog. The router exposes an OpenAI-compatible listing at ``https://router.huggingface.co/v1/models`` with per-provider availability, pricing, context length, and tool-use support. We use it to: • Validate ``/model`` switches with live data instead of a hard-coded allowlist. • Show the user which providers serve a model, at what price, and whether they support tool calls. • Derive a reasonable context-window limit for any routed model. The listing is cached in-memory for a few minutes so repeated lookups during a session are free. On fetch failure we return stale data if we have it, or an empty catalog otherwise. """ import logging import time from dataclasses import dataclass from difflib import get_close_matches from typing import Optional import httpx logger = logging.getLogger(__name__) _CATALOG_URL = "https://router.huggingface.co/v1/models" _CACHE_TTL_SECONDS = 300 _HTTP_TIMEOUT_SECONDS = 5.0 _cache: Optional[dict] = None _cache_time: float = 0.0 @dataclass class ProviderInfo: provider: str status: str context_length: Optional[int] input_price: Optional[float] output_price: Optional[float] supports_tools: bool supports_structured_output: bool @dataclass class ModelInfo: id: str providers: list[ProviderInfo] @property def live_providers(self) -> list[ProviderInfo]: return [p for p in self.providers if p.status == "live"] @property def max_context_length(self) -> Optional[int]: lengths = [p.context_length for p in self.live_providers if p.context_length] return max(lengths) if lengths else None @property def any_supports_tools(self) -> bool: return any(p.supports_tools for p in self.live_providers) def _fetch_catalog(force: bool = False) -> dict: global _cache, _cache_time now = time.time() if not force and _cache is not None and now - _cache_time < _CACHE_TTL_SECONDS: return _cache try: resp = httpx.get(_CATALOG_URL, timeout=_HTTP_TIMEOUT_SECONDS) resp.raise_for_status() _cache = resp.json() _cache_time = now except Exception as e: logger.warning("Failed to fetch HF router catalog: %s", e) if _cache is None: _cache = {"data": []} _cache_time = now return _cache def _parse_entry(entry: dict) -> ModelInfo: providers = [] for p in entry.get("providers", []) or []: pricing = p.get("pricing") or {} providers.append( ProviderInfo( provider=p.get("provider", ""), status=p.get("status", ""), context_length=p.get("context_length"), input_price=pricing.get("input"), output_price=pricing.get("output"), supports_tools=bool(p.get("supports_tools", False)), supports_structured_output=bool(p.get("supports_structured_output", False)), ) ) return ModelInfo(id=entry.get("id", ""), providers=providers) def lookup(model_id: str) -> Optional[ModelInfo]: """Find a model in the router catalog. Accepts ``<org>/<model>`` or ``<org>/<model>:<tag>`` — the tag is stripped for lookup. Returns ``None`` if the model isn't listed.
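    Example (model id is illustrative):

        info = lookup("moonshotai/Kimi-K2.6:cheapest")  # matched as "moonshotai/Kimi-K2.6"
        if info is not None:
            print(info.max_context_length, info.any_supports_tools)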
""" bare = model_id.split(":", 1)[0] catalog = _fetch_catalog() for entry in catalog.get("data", []): if entry.get("id") == bare: return _parse_entry(entry) return None def fuzzy_suggest(model_id: str, limit: int = 3) -> list[str]: """Return the closest model ids from the catalog.""" bare = model_id.split(":", 1)[0] catalog = _fetch_catalog() ids = [e.get("id", "") for e in catalog.get("data", []) if e.get("id")] return get_close_matches(bare, ids, n=limit, cutoff=0.4) def prewarm() -> None: """Fetch the catalog so subsequent lookups are instant. Safe to call from a background task — swallows failures.""" try: _fetch_catalog(force=False) except Exception: pass ================================================ FILE: agent/core/llm_params.py ================================================ """LiteLLM kwargs resolution for the model ids this agent accepts. Kept separate from ``agent_loop`` so tools (research, context compaction, etc.) can import it without pulling in the whole agent loop / tool router and creating circular imports. """ import os def _patch_litellm_effort_validation() -> None: """Neuter LiteLLM 1.83's hardcoded effort-level validation. Context: at ``litellm/llms/anthropic/chat/transformation.py:~1443`` the Anthropic adapter validates ``output_config.effort ∈ {high, medium, low, max}`` and gates ``max`` behind an ``_is_opus_4_6_model`` check that only matches the substring ``opus-4-6`` / ``opus_4_6``. Result: * ``xhigh`` — valid on Anthropic's real API for Claude 4.7 — is rejected pre-flight with "Invalid effort value: xhigh". * ``max`` on Opus 4.7 is rejected with "effort='max' is only supported by Claude Opus 4.6", even though Opus 4.7 accepts it in practice. We don't want to maintain a parallel model table, so we let the Anthropic API itself be the validator: widen ``_is_opus_4_6_model`` to also match ``opus-4-7``+ families, and drop the valid-effort-set check entirely. If Anthropic rejects an effort level, we see a 400 and the cascade walks down — exactly the behavior we want for any future model family. Removable once litellm ships 1.83.8-stable (which merges PR #25867, "Litellm day 0 opus 4.7 support") — see commit 0868a82 on their main branch. Until then, this one-time patch is the escape hatch. """ try: from litellm.llms.anthropic.chat import transformation as _t except Exception: return cfg = getattr(_t, "AnthropicConfig", None) if cfg is None: return original = getattr(cfg, "_is_opus_4_6_model", None) if original is None or getattr(original, "_hf_agent_patched", False): return def _widened(model: str) -> bool: m = model.lower() # Original 4.6 match plus any future Opus >= 4.6. We only need this # to return True for families where "max" / "xhigh" are acceptable # at the API; the cascade handles the case when they're not. return any( v in m for v in ( "opus-4-6", "opus_4_6", "opus-4.6", "opus_4.6", "opus-4-7", "opus_4_7", "opus-4.7", "opus_4.7", ) ) _widened._hf_agent_patched = True # type: ignore[attr-defined] cfg._is_opus_4_6_model = staticmethod(_widened) _patch_litellm_effort_validation() # Effort levels accepted on the wire. # Anthropic (4.6+): low | medium | high | xhigh | max (output_config.effort) # OpenAI direct: minimal | low | medium | high (reasoning_effort top-level) # HF router: low | medium | high (extra_body.reasoning_effort) # # We validate *shape* here and let the probe cascade walk down on rejection; # we deliberately do NOT maintain a per-model capability table. 
_ANTHROPIC_EFFORTS = {"low", "medium", "high", "xhigh", "max"} _OPENAI_EFFORTS = {"minimal", "low", "medium", "high"} _HF_EFFORTS = {"low", "medium", "high"} class UnsupportedEffortError(ValueError): """The requested effort isn't valid for this provider's API surface. Raised synchronously before any network call so the probe cascade can skip levels the provider can't accept (e.g. ``max`` on HF router). """ def _resolve_llm_params( model_name: str, session_hf_token: str | None = None, reasoning_effort: str | None = None, strict: bool = False, ) -> dict: """ Build LiteLLM kwargs for a given model id. • ``anthropic/`` — native thinking config. We bypass LiteLLM's ``reasoning_effort`` → ``thinking`` mapping (which lags new Claude releases like 4.7 and sends the wrong API shape). Instead we pass both ``thinking={"type": "adaptive"}`` and ``output_config= {"effort": }`` as top-level kwargs — LiteLLM's Anthropic adapter forwards unknown top-level kwargs into the request body verbatim (confirmed by live probe; ``extra_body`` does NOT work here because Anthropic's API rejects it as "Extra inputs are not permitted"). This is the stable API for 4.6 and 4.7. Older extended-thinking models that only accept ``thinking.type.enabled`` will reject this; the probe's cascade catches that and falls back to no thinking. • ``openai/`` — ``reasoning_effort`` forwarded as a top-level kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``. • Anything else is treated as a HuggingFace router id. We hit the auto-routing OpenAI-compatible endpoint at ``https://router.huggingface.co/v1``. The id can be bare or carry an HF routing suffix (``:fastest`` / ``:cheapest`` / ``:``). A leading ``huggingface/`` is stripped. ``reasoning_effort`` is forwarded via ``extra_body`` (LiteLLM's OpenAI adapter refuses it as a top-level kwarg for non-OpenAI models). "minimal" normalizes to "low". ``strict=True`` raises ``UnsupportedEffortError`` when the requested effort isn't in the provider's accepted set, instead of silently dropping it. The probe cascade uses strict mode so it can walk down (``max`` → ``xhigh`` → ``high`` …) without making an API call. Regular runtime callers leave ``strict=False``, so a stale cached effort can't crash a turn — it just doesn't get sent. Token precedence (first non-empty wins): 1. INFERENCE_TOKEN env — shared key on the hosted Space (inference is free for users, billed to the Space owner via ``X-HF-Bill-To``). 2. session.hf_token — the user's own token (CLI / OAuth / cache file). 3. HF_TOKEN env — belt-and-suspenders fallback for CLI users. """ if model_name.startswith("anthropic/"): params: dict = {"model": model_name} if reasoning_effort: level = reasoning_effort if level == "minimal": level = "low" if level not in _ANTHROPIC_EFFORTS: if strict: raise UnsupportedEffortError( f"Anthropic doesn't accept effort={level!r}" ) else: # Adaptive thinking + output_config.effort is the stable # Anthropic API for Claude 4.6 / 4.7. Both kwargs are # passed top-level: LiteLLM forwards unknown params into # the request body for Anthropic, so ``output_config`` # reaches the API. ``extra_body`` does NOT work here — # Anthropic rejects it as "Extra inputs are not # permitted". 
params["thinking"] = {"type": "adaptive"} params["output_config"] = {"effort": level} return params if model_name.startswith("bedrock/"): # LiteLLM routes ``bedrock/...`` through the Converse adapter, which # picks up AWS credentials from the standard env vars # (``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` / ``AWS_REGION``). # The Anthropic thinking/effort shape is not forwarded through Converse # the same way, so we leave it off for now. return {"model": model_name} if model_name.startswith("openai/"): params = {"model": model_name} if reasoning_effort: if reasoning_effort not in _OPENAI_EFFORTS: if strict: raise UnsupportedEffortError( f"OpenAI doesn't accept effort={reasoning_effort!r}" ) else: params["reasoning_effort"] = reasoning_effort return params hf_model = model_name.removeprefix("huggingface/") api_key = ( os.environ.get("INFERENCE_TOKEN") or session_hf_token or os.environ.get("HF_TOKEN") ) params = { "model": f"openai/{hf_model}", "api_base": "https://router.huggingface.co/v1", "api_key": api_key, } if os.environ.get("INFERENCE_TOKEN"): bill_to = os.environ.get("HF_BILL_TO", "smolagents") params["extra_headers"] = {"X-HF-Bill-To": bill_to} if reasoning_effort: hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort if hf_level not in _HF_EFFORTS: if strict: raise UnsupportedEffortError( f"HF router doesn't accept effort={hf_level!r}" ) else: params["extra_body"] = {"reasoning_effort": hf_level} return params ================================================ FILE: agent/core/model_switcher.py ================================================ """Model-switching logic for the interactive CLI's ``/model`` command. Split out of ``agent.main`` so the REPL dispatcher stays focused on input parsing. Exposes: * ``SUGGESTED_MODELS`` — the short list shown by ``/model`` with no arg. * ``is_valid_model_id`` — loose format check on user input. * ``probe_and_switch_model`` — async: checks routing, fires a 1-token probe to resolve the effort cascade, then commits the switch (or rejects it on hard error). The probe's cascade lives in ``agent.core.effort_probe``; this module glues it to CLI output + session state. """ from __future__ import annotations from agent.core.effort_probe import ProbeInconclusive, probe_effort # Suggested models shown by `/model` (not a gate). Users can paste any HF # model id (e.g. "MiniMaxAI/MiniMax-M2.7") or an `anthropic/` / `openai/` # prefix for direct API access. For HF ids, append ":fastest" / # ":cheapest" / ":preferred" / ":" to override the default # routing policy (auto = fastest with failover). SUGGESTED_MODELS = [ {"id": "bedrock/us.anthropic.claude-opus-4-7", "label": "Claude Opus 4.7"}, {"id": "bedrock/us.anthropic.claude-opus-4-6-v1", "label": "Claude Opus 4.6"}, {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"}, {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"}, {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"}, ] _ROUTING_POLICIES = {"fastest", "cheapest", "preferred"} def is_valid_model_id(model_id: str) -> bool: """Loose format check — lets users pick any model id. Accepts: • anthropic/ • openai//[:] (HF router; tag = provider or policy) • huggingface//[:] (same, accepts legacy prefix) Actual availability is verified against the HF router catalog on switch, and by the provider on the probe's ping call. 
""" if not model_id or "/" not in model_id: return False head = model_id.split(":", 1)[0] parts = head.split("/") return len(parts) >= 2 and all(parts) def _print_hf_routing_info(model_id: str, console) -> bool: """Show HF router catalog info (providers, price, context, tool support) for an HF-router model id. Returns ``True`` to signal the caller can proceed with the switch, ``False`` to indicate a hard problem the user should notice before we fire the effort probe. Anthropic / OpenAI ids return ``True`` without printing anything — the probe below covers "does this model exist". """ if model_id.startswith(("anthropic/", "openai/")): return True from agent.core import hf_router_catalog as cat bare, _, tag = model_id.partition(":") info = cat.lookup(bare) if info is None: console.print( f"[bold red]Warning:[/bold red] '{bare}' isn't in the HF router " "catalog. Checking anyway — first call may fail." ) suggestions = cat.fuzzy_suggest(bare) if suggestions: console.print(f"[dim]Did you mean: {', '.join(suggestions)}[/dim]") return True live = info.live_providers if not live: console.print( f"[bold red]Warning:[/bold red] '{bare}' has no live providers " "right now. First call will likely fail." ) return True if tag and tag not in _ROUTING_POLICIES: matched = [p for p in live if p.provider == tag] if not matched: names = ", ".join(p.provider for p in live) console.print( f"[bold red]Warning:[/bold red] provider '{tag}' doesn't serve " f"'{bare}'. Live providers: {names}. Checking anyway." ) if not info.any_supports_tools: console.print( f"[bold red]Warning:[/bold red] no provider for '{bare}' advertises " "tool-call support. This agent relies on tool calls — expect errors." ) if tag in _ROUTING_POLICIES: policy = tag elif tag: policy = f"pinned to {tag}" else: policy = "auto (fastest)" console.print(f" [dim]routing: {policy}[/dim]") for p in live: price = ( f"${p.input_price:g}/${p.output_price:g} per M tok" if p.input_price is not None and p.output_price is not None else "price n/a" ) ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a" tools = "tools" if p.supports_tools else "no tools" console.print( f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]" ) return True def print_model_listing(config, console) -> None: """Render the default ``/model`` (no-arg) view: current + suggested.""" current = config.model_name if config else "" console.print("[bold]Current model:[/bold]") console.print(f" {current}") console.print("\n[bold]Suggested:[/bold]") for m in SUGGESTED_MODELS: marker = " [dim]<-- current[/dim]" if m["id"] == current else "" console.print(f" {m['id']} [dim]({m['label']})[/dim]{marker}") console.print( "\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n" "Add ':fastest', ':cheapest', ':preferred', or ':' to override routing.\n" "Use 'anthropic/' or 'openai/' for direct API access.[/dim]" ) def print_invalid_id(arg: str, console) -> None: console.print(f"[bold red]Invalid model id format:[/bold red] {arg}") console.print( "[dim]Expected:\n" " • /[:tag] (HF router — paste from huggingface.co)\n" " • anthropic/\n" " • openai/[/dim]" ) async def probe_and_switch_model( model_id: str, config, session, console, hf_token: str | None, ) -> None: """Validate model+effort with a 1-token ping, cache the effective effort, then commit the switch. 
Three visible outcomes: * ✓ ``effort: <level>`` — model accepted the preferred effort (or a fallback from the cascade; the note explains if so) * ✓ ``effort: off`` — model doesn't support thinking; we'll strip it * ✗ hard error (auth, model-not-found, quota) — we reject the switch and keep the current model so the user isn't stranded Transient errors (5xx, timeout) complete the switch with a yellow warning; the next real call re-surfaces the error if it's persistent. """ preference = config.reasoning_effort if not _print_hf_routing_info(model_id, console): return if not preference: # Nothing to validate with a ping that we couldn't validate on the # first real call just as cheaply. Skip the probe entirely. _commit_switch(model_id, config, session, effective=None, cache=False) console.print(f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]") return console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]") try: outcome = await probe_effort(model_id, preference, hf_token) except ProbeInconclusive as e: _commit_switch(model_id, config, session, effective=None, cache=False) console.print( f"[yellow]Model switched to {model_id}[/yellow] " f"[dim](couldn't validate: {e}; will verify on first message)[/dim]" ) return except Exception as e: # Hard persistent error — auth, unknown model, quota. Don't switch. console.print(f"[bold red]Switch failed:[/bold red] {e}") console.print(f"[dim]Keeping current model: {config.model_name}[/dim]") return _commit_switch( model_id, config, session, effective=outcome.effective_effort, cache=True, ) effort_label = outcome.effective_effort or "off" suffix = f" — {outcome.note}" if outcome.note else "" console.print( f"[green]Model switched to {model_id}[/green] " f"[dim](effort: {effort_label}{suffix}, {outcome.elapsed_ms}ms)[/dim]" ) def _commit_switch(model_id, config, session, effective, cache: bool) -> None: """Apply the switch to the session (or bare config if no session yet). ``effective`` is the probe's resolved effort; ``cache=True`` stores it in the session's per-model cache so real calls use the resolved level instead of re-probing. ``cache=False`` (inconclusive probe / effort off) leaves the cache untouched — next call falls back to preference. """ if session is not None: session.update_model(model_id) if cache: session.model_effective_effort[model_id] = effective else: session.model_effective_effort.pop(model_id, None) else: config.model_name = model_id ================================================ FILE: agent/core/prompt_caching.py ================================================ """Anthropic prompt caching breakpoints for outgoing LLM requests. Caching is GA on Anthropic's API and natively supported by litellm >=1.83 via ``cache_control`` blocks. We apply two breakpoints (out of 4 allowed): 1. The tool block — caches all tool definitions as a single prefix. 2. The system message — caches the rendered system prompt. Together these cover the ~4-5K static tokens that were being re-billed on every turn. Subsequent turns within the 5-minute TTL hit cache_read pricing (~10% of input cost) instead of full input. Non-Anthropic models (HF router, OpenAI) are passed through unchanged. """ from typing import Any def with_prompt_caching( messages: list[Any], tools: list[dict] | None, model_name: str | None, ) -> tuple[list[Any], list[dict] | None]: """Return (messages, tools) with cache_control breakpoints for Anthropic. No-op for non-Anthropic models.
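    A sketch of the effect on an Anthropic call (shapes only; the model id
    and prompt text are illustrative):

        msgs, tools = with_prompt_caching(messages, tool_specs, "anthropic/claude-sonnet-4-5")
        # msgs[0] == {"role": "system", "content": [{"type": "text",
        #             "text": "...", "cache_control": {"type": "ephemeral"}}]}
        # tools[-1] carries the same {"type": "ephemeral"} marker.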
Original objects are not mutated; a fresh list with replaced first message and last tool is returned, so callers that share the underlying ``ContextManager.items`` list don't see their persisted history rewritten. """ if not model_name or "anthropic" not in model_name: return messages, tools if tools: new_tools = list(tools) last = dict(new_tools[-1]) last["cache_control"] = {"type": "ephemeral"} new_tools[-1] = last tools = new_tools if messages: first = messages[0] role = first.get("role") if isinstance(first, dict) else getattr(first, "role", None) if role == "system": content = ( first.get("content") if isinstance(first, dict) else getattr(first, "content", None) ) if isinstance(content, str) and content: cached_block = [{ "type": "text", "text": content, "cache_control": {"type": "ephemeral"}, }] new_first = {"role": "system", "content": cached_block} messages = [new_first] + list(messages[1:]) return messages, tools ================================================ FILE: agent/core/session.py ================================================ import asyncio import json import logging import subprocess import sys import uuid from dataclasses import dataclass from datetime import datetime from enum import Enum from pathlib import Path from typing import Any, Optional from agent.config import Config from agent.context_manager.manager import ContextManager logger = logging.getLogger(__name__) _DEFAULT_MAX_TOKENS = 200_000 def _get_max_tokens_safe(model_name: str) -> int: """Return the max input-context tokens for a model. Primary source: ``litellm.get_model_info(model)['max_input_tokens']`` — LiteLLM maintains an upstream catalog that knows Claude Opus 4.6 is 1M, GPT-5 is 272k, Sonnet 4.5 is 200k, and so on. Strips any HF routing suffix / huggingface/ prefix so tagged ids ('moonshotai/Kimi-K2.6:cheapest') look up the bare model. Falls back to a conservative 200k default for models not in the catalog (typically HF-router-only models). 
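    Example (id is illustrative):

        _get_max_tokens_safe("huggingface/moonshotai/Kimi-K2.6:cheapest")
        # tries the raw id first, then the stripped "moonshotai/Kimi-K2.6";
        # returns the 200_000 default if neither is in litellm's catalog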
""" from litellm import get_model_info candidates = [model_name] stripped = model_name.removeprefix("huggingface/").split(":", 1)[0] if stripped != model_name: candidates.append(stripped) for candidate in candidates: try: info = get_model_info(candidate) max_input = info.get("max_input_tokens") if info else None if isinstance(max_input, int) and max_input > 0: return max_input except Exception: continue logger.info( "No litellm.get_model_info entry for %s, falling back to %d", model_name, _DEFAULT_MAX_TOKENS, ) return _DEFAULT_MAX_TOKENS class OpType(Enum): USER_INPUT = "user_input" EXEC_APPROVAL = "exec_approval" INTERRUPT = "interrupt" UNDO = "undo" COMPACT = "compact" SHUTDOWN = "shutdown" @dataclass class Event: event_type: str data: Optional[dict[str, Any]] = None class Session: """ Maintains agent session state Similar to Session in codex-rs/core/src/codex.rs """ def __init__( self, event_queue: asyncio.Queue, config: Config | None = None, tool_router=None, context_manager: ContextManager | None = None, hf_token: str | None = None, local_mode: bool = False, stream: bool = True, ): self.hf_token: Optional[str] = hf_token self.tool_router = tool_router self.stream = stream tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else [] self.context_manager = context_manager or ContextManager( model_max_tokens=_get_max_tokens_safe(config.model_name), compact_size=0.1, untouched_messages=5, tool_specs=tool_specs, hf_token=hf_token, local_mode=local_mode, ) self.event_queue = event_queue self.session_id = str(uuid.uuid4()) self.config = config or Config( model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0", ) self.is_running = True self._cancelled = asyncio.Event() self.pending_approval: Optional[dict[str, Any]] = None self.sandbox = None self._running_job_ids: set[str] = set() # HF job IDs currently executing # Session trajectory logging self.logged_events: list[dict] = [] self.session_start_time = datetime.now().isoformat() self.turn_count: int = 0 self.last_auto_save_turn: int = 0 # Per-model probed reasoning-effort cache. Populated by the probe # on /model switch, read by ``effective_effort_for`` below. Keys are # raw model ids (including any ``:tag``). Values: # str → the effort level to send (may be a downgrade from the # preference, e.g. "high" when user asked for "max") # None → model rejected all efforts in the cascade; send no # thinking params at all # Key absent → not probed yet; fall back to the raw preference. self.model_effective_effort: dict[str, str | None] = {} async def send_event(self, event: Event) -> None: """Send event back to client and log to trajectory""" await self.event_queue.put(event) # Log event to trajectory self.logged_events.append( { "timestamp": datetime.now().isoformat(), "event_type": event.event_type, "data": event.data, } ) def cancel(self) -> None: """Signal cancellation to the running agent loop.""" self._cancelled.set() def reset_cancel(self) -> None: """Clear the cancellation flag before a new run.""" self._cancelled.clear() @property def is_cancelled(self) -> bool: return self._cancelled.is_set() def update_model(self, model_name: str) -> None: """Switch the active model and update the context window limit.""" self.config.model_name = model_name self.context_manager.model_max_tokens = _get_max_tokens_safe(model_name) def effective_effort_for(self, model_name: str) -> str | None: """Resolve the effort level to actually send for ``model_name``. 
Returns the probed result when we have one (may be ``None`` meaning "model doesn't do thinking, strip it"), else the raw preference. Unknown-model case falls back to the preference so a stale cache from a prior ``/model`` can't poison research sub-calls that use a different model id. """ if model_name in self.model_effective_effort: return self.model_effective_effort[model_name] return self.config.reasoning_effort def increment_turn(self) -> None: """Increment turn counter (called after each user interaction)""" self.turn_count += 1 async def auto_save_if_needed(self) -> None: """Check if auto-save should trigger and save if so (completely non-blocking)""" if not self.config.save_sessions: return interval = self.config.auto_save_interval if interval <= 0: return turns_since_last_save = self.turn_count - self.last_auto_save_turn if turns_since_last_save >= interval: logger.info(f"Auto-saving session (turn {self.turn_count})...") # Fire-and-forget save - returns immediately self.save_and_upload_detached(self.config.session_dataset_repo) self.last_auto_save_turn = self.turn_count def get_trajectory(self) -> dict: """Serialize complete session trajectory for logging""" return { "session_id": self.session_id, "session_start_time": self.session_start_time, "session_end_time": datetime.now().isoformat(), "model_name": self.config.model_name, "messages": [msg.model_dump() for msg in self.context_manager.items], "events": self.logged_events, } def save_trajectory_local( self, directory: str = "session_logs", upload_status: str = "pending", dataset_url: Optional[str] = None, ) -> Optional[str]: """ Save trajectory to local JSON file as backup with upload status Args: directory: Directory to save logs (default: "session_logs") upload_status: Status of upload attempt ("pending", "success", "failed") dataset_url: URL of dataset if upload succeeded Returns: Path to saved file if successful, None otherwise """ try: log_dir = Path(directory) log_dir.mkdir(parents=True, exist_ok=True) trajectory = self.get_trajectory() # Add upload metadata trajectory["upload_status"] = upload_status trajectory["upload_url"] = dataset_url trajectory["last_save_time"] = datetime.now().isoformat() filename = f"session_{self.session_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" filepath = log_dir / filename with open(filepath, "w") as f: json.dump(trajectory, f, indent=2) return str(filepath) except Exception as e: logger.error(f"Failed to save session locally: {e}") return None def update_local_save_status( self, filepath: str, upload_status: str, dataset_url: Optional[str] = None ) -> bool: """Update the upload status of an existing local save file""" try: with open(filepath, "r") as f: data = json.load(f) data["upload_status"] = upload_status data["upload_url"] = dataset_url data["last_save_time"] = datetime.now().isoformat() with open(filepath, "w") as f: json.dump(data, f, indent=2) return True except Exception as e: logger.error(f"Failed to update local save status: {e}") return False def save_and_upload_detached(self, repo_id: str) -> Optional[str]: """ Save session locally and spawn detached subprocess for upload (fire-and-forget) Args: repo_id: HuggingFace dataset repo ID Returns: Path to local save file """ # Save locally first (fast, synchronous) local_path = self.save_trajectory_local(upload_status="pending") if not local_path: return None # Spawn detached subprocess for upload (fire-and-forget) try: uploader_script = Path(__file__).parent / "session_uploader.py" # Use Popen with detached process 
subprocess.Popen( [sys.executable, str(uploader_script), "upload", local_path, repo_id], stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True, # Detach from parent ) except Exception as e: logger.warning(f"Failed to spawn upload subprocess: {e}") return local_path @staticmethod def retry_failed_uploads_detached( directory: str = "session_logs", repo_id: Optional[str] = None ) -> None: """ Spawn detached subprocess to retry failed/pending uploads (fire-and-forget) Args: directory: Directory containing session logs repo_id: Target dataset repo ID """ if not repo_id: return try: uploader_script = Path(__file__).parent / "session_uploader.py" # Spawn detached subprocess for retry subprocess.Popen( [sys.executable, str(uploader_script), "retry", directory, repo_id], stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True, # Detach from parent ) except Exception as e: logger.warning(f"Failed to spawn retry subprocess: {e}") ================================================ FILE: agent/core/session_uploader.py ================================================ #!/usr/bin/env python3 """ Standalone script for uploading session trajectories to HuggingFace. This runs as a separate process to avoid blocking the main agent. Uses individual file uploads to avoid race conditions. """ import json import os import sys from datetime import datetime from pathlib import Path from dotenv import load_dotenv load_dotenv() # Token for session uploads — loaded from env var (never hardcode tokens in source) _SESSION_TOKEN = os.environ.get("HF_SESSION_UPLOAD_TOKEN", "") def upload_session_as_file( session_file: str, repo_id: str, max_retries: int = 3 ) -> bool: """ Upload a single session as an individual JSONL file (no race conditions) Args: session_file: Path to local session JSON file repo_id: HuggingFace dataset repo ID max_retries: Number of retry attempts Returns: True if successful, False otherwise """ try: from huggingface_hub import HfApi except ImportError: print("Error: huggingface_hub library not available", file=sys.stderr) return False try: # Load session data with open(session_file, "r") as f: data = json.load(f) # Check if already uploaded upload_status = data.get("upload_status") if upload_status == "success": return True # Use dedicated session upload token (write-only access to session dataset) hf_token = _SESSION_TOKEN if not hf_token: # Update status to failed data["upload_status"] = "failed" with open(session_file, "w") as f: json.dump(data, f, indent=2) return False # Prepare JSONL content (single line) # Store messages and events as JSON strings to avoid schema conflicts session_row = { "session_id": data["session_id"], "session_start_time": data["session_start_time"], "session_end_time": data["session_end_time"], "model_name": data["model_name"], "messages": json.dumps(data["messages"]), "events": json.dumps(data["events"]), } # Create temporary JSONL file import tempfile with tempfile.NamedTemporaryFile( mode="w", suffix=".jsonl", delete=False ) as tmp: json.dump(session_row, tmp) # Single line JSON tmp_path = tmp.name try: # Generate unique path in repo: sessions/YYYY-MM-DD/session_id.jsonl session_id = data["session_id"] date_str = datetime.fromisoformat(data["session_start_time"]).strftime( "%Y-%m-%d" ) repo_path = f"sessions/{date_str}/{session_id}.jsonl" # Upload with retries api = HfApi() for attempt in range(max_retries): try: # Try to create repo if it doesn't exist (idempotent) try: 
api.create_repo( repo_id=repo_id, repo_type="dataset", private=False, token=hf_token, exist_ok=True, # Don't fail if already exists ) except Exception: # Repo might already exist, continue pass # Upload the session file api.upload_file( path_or_fileobj=tmp_path, path_in_repo=repo_path, repo_id=repo_id, repo_type="dataset", token=hf_token, commit_message=f"Add session {session_id}", ) # Update local status to success data["upload_status"] = "success" data["upload_url"] = f"https://huggingface.co/datasets/{repo_id}" with open(session_file, "w") as f: json.dump(data, f, indent=2) return True except Exception: if attempt < max_retries - 1: import time wait_time = 2**attempt time.sleep(wait_time) else: # Final attempt failed data["upload_status"] = "failed" with open(session_file, "w") as f: json.dump(data, f, indent=2) return False finally: # Clean up temp file try: os.unlink(tmp_path) except Exception: pass except Exception as e: print(f"Error uploading session: {e}", file=sys.stderr) return False def retry_failed_uploads(directory: str, repo_id: str): """Retry all failed/pending uploads in a directory""" log_dir = Path(directory) if not log_dir.exists(): return session_files = list(log_dir.glob("session_*.json")) for filepath in session_files: try: with open(filepath, "r") as f: data = json.load(f) upload_status = data.get("upload_status", "unknown") # Only retry pending or failed uploads if upload_status in ["pending", "failed"]: upload_session_as_file(str(filepath), repo_id) except Exception: pass if __name__ == "__main__": if len(sys.argv) < 3: print("Usage: session_uploader.py <command> <args>") sys.exit(1) command = sys.argv[1] if command == "upload": # python session_uploader.py upload <session_file> <repo_id> if len(sys.argv) < 4: print("Usage: session_uploader.py upload <session_file> <repo_id>") sys.exit(1) session_file = sys.argv[2] repo_id = sys.argv[3] success = upload_session_as_file(session_file, repo_id) sys.exit(0 if success else 1) elif command == "retry": # python session_uploader.py retry <directory> <repo_id> if len(sys.argv) < 4: print("Usage: session_uploader.py retry <directory> <repo_id>") sys.exit(1) directory = sys.argv[2] repo_id = sys.argv[3] retry_failed_uploads(directory, repo_id) sys.exit(0) else: print(f"Unknown command: {command}") sys.exit(1) ================================================ FILE: agent/core/tools.py ================================================ """ Tool system for the agent Provides ToolSpec and ToolRouter for managing both built-in and MCP tools """ import logging import warnings from dataclasses import dataclass from typing import Any, Awaitable, Callable, Optional logger = logging.getLogger(__name__) from fastmcp import Client from fastmcp.exceptions import ToolError from mcp.types import EmbeddedResource, ImageContent, TextContent from agent.config import MCPServerConfig from agent.tools.dataset_tools import ( HF_INSPECT_DATASET_TOOL_SPEC, hf_inspect_dataset_handler, ) from agent.tools.docs_tools import ( EXPLORE_HF_DOCS_TOOL_SPEC, HF_DOCS_FETCH_TOOL_SPEC, explore_hf_docs_handler, hf_docs_fetch_handler, ) from agent.tools.github_find_examples import ( GITHUB_FIND_EXAMPLES_TOOL_SPEC, github_find_examples_handler, ) from agent.tools.github_list_repos import ( GITHUB_LIST_REPOS_TOOL_SPEC, github_list_repos_handler, ) from agent.tools.github_read_file import ( GITHUB_READ_FILE_TOOL_SPEC, github_read_file_handler, ) from agent.tools.hf_repo_files_tool import ( HF_REPO_FILES_TOOL_SPEC, hf_repo_files_handler, ) from agent.tools.hf_repo_git_tool import ( HF_REPO_GIT_TOOL_SPEC, hf_repo_git_handler, ) from agent.tools.jobs_tool import
HF_JOBS_TOOL_SPEC, hf_jobs_handler from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler from agent.tools.sandbox_tool import get_sandbox_tools # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git # from agent.tools.private_hf_repo_tools import ( # PRIVATE_HF_REPO_TOOL_SPEC, # private_hf_repo_handler, # ) # Suppress aiohttp deprecation warning warnings.filterwarnings( "ignore", category=DeprecationWarning, module="aiohttp.connector" ) NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"] def convert_mcp_content_to_string(content: list) -> str: """ Convert MCP content blocks to a string format compatible with LLM messages. Based on FastMCP documentation, content can be: - TextContent: has .text field - ImageContent: has .data and .mimeType fields - EmbeddedResource: has .resource field with .text or .blob Args: content: List of MCP content blocks Returns: String representation of the content suitable for LLM consumption """ if not content: return "" parts = [] for item in content: if isinstance(item, TextContent): # Extract text from TextContent blocks parts.append(item.text) elif isinstance(item, ImageContent): # TODO: Handle images # For images, include a description with MIME type parts.append(f"[Image: {item.mimeType}]") elif isinstance(item, EmbeddedResource): # TODO: Handle embedded resources # For embedded resources, try to extract text resource = item.resource if hasattr(resource, "text") and resource.text: parts.append(resource.text) elif hasattr(resource, "blob") and resource.blob: parts.append( f"[Binary data: {resource.mimeType if hasattr(resource, 'mimeType') else 'unknown'}]" ) else: parts.append( f"[Resource: {resource.uri if hasattr(resource, 'uri') else 'unknown'}]" ) else: # Fallback: try to convert to string parts.append(str(item)) return "\n".join(parts) @dataclass class ToolSpec: """Tool specification for LLM""" name: str description: str parameters: dict[str, Any] handler: Optional[Callable[[dict[str, Any]], Awaitable[tuple[str, bool]]]] = None class ToolRouter: """ Routes tool calls to appropriate handlers. 
Based on codex-rs/core/src/tools/router.rs """ def __init__(self, mcp_servers: dict[str, MCPServerConfig], hf_token: str | None = None, local_mode: bool = False): self.tools: dict[str, ToolSpec] = {} self.mcp_servers: dict[str, dict[str, Any]] = {} for tool in create_builtin_tools(local_mode=local_mode): self.register_tool(tool) self.mcp_client: Client | None = None if mcp_servers: mcp_servers_payload = {} for name, server in mcp_servers.items(): data = server.model_dump() if hf_token: data.setdefault("headers", {})["Authorization"] = f"Bearer {hf_token}" mcp_servers_payload[name] = data self.mcp_client = Client({"mcpServers": mcp_servers_payload}) self._mcp_initialized = False def register_tool(self, tool: ToolSpec) -> None: self.tools[tool.name] = tool async def register_mcp_tools(self) -> None: tools = await self.mcp_client.list_tools() registered_names = [] skipped_count = 0 for tool in tools: if tool.name in NOT_ALLOWED_TOOL_NAMES: skipped_count += 1 continue registered_names.append(tool.name) self.register_tool( ToolSpec( name=tool.name, description=tool.description, parameters=tool.inputSchema, handler=None, ) ) logger.info( f"Loaded {len(registered_names)} MCP tools: {', '.join(registered_names)} ({skipped_count} disabled)" ) async def register_openapi_tool(self) -> None: """Register the OpenAPI search tool (requires async initialization)""" from agent.tools.docs_tools import ( _get_api_search_tool_spec, search_openapi_handler, ) try: openapi_spec = await _get_api_search_tool_spec() self.register_tool( ToolSpec( name=openapi_spec["name"], description=openapi_spec["description"], parameters=openapi_spec["parameters"], handler=search_openapi_handler, ) ) logger.info(f"Loaded OpenAPI search tool: {openapi_spec['name']}") except Exception as e: logger.warning("Failed to load OpenAPI search tool: %s", e) def get_tool_specs_for_llm(self) -> list[dict[str, Any]]: """Get tool specifications in OpenAI format""" specs = [] for tool in self.tools.values(): specs.append( { "type": "function", "function": { "name": tool.name, "description": tool.description, "parameters": tool.parameters, }, } ) return specs async def __aenter__(self) -> "ToolRouter": if self.mcp_client is not None: try: await self.mcp_client.__aenter__() await self.mcp_client.initialize() await self.register_mcp_tools() self._mcp_initialized = True except Exception as e: logger.warning("MCP connection failed, continuing without MCP tools: %s", e) self.mcp_client = None await self.register_openapi_tool() total_tools = len(self.tools) logger.info(f"Agent ready with {total_tools} tools total") return self async def __aexit__(self, exc_type, exc, tb) -> None: if self.mcp_client is not None: await self.mcp_client.__aexit__(exc_type, exc, tb) self._mcp_initialized = False async def call_tool( self, tool_name: str, arguments: dict[str, Any], session: Any = None, tool_call_id: str | None = None, ) -> tuple[str, bool]: """ Call a tool and return (output_string, success_bool). For MCP tools, converts the CallToolResult content blocks to a string. For built-in tools, calls their handler directly. 
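Built-in handlers may optionally accept ``session`` and ``tool_call_id``
keyword arguments; the router inspects each handler's signature and only
passes the parameters it declares.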
""" # Check if this is a built-in tool with a handler tool = self.tools.get(tool_name) if tool and tool.handler: import inspect # Check if handler accepts session argument sig = inspect.signature(tool.handler) if "session" in sig.parameters: # Check if handler also accepts tool_call_id parameter if "tool_call_id" in sig.parameters: return await tool.handler( arguments, session=session, tool_call_id=tool_call_id ) return await tool.handler(arguments, session=session) return await tool.handler(arguments) # Otherwise, use MCP client if self._mcp_initialized: try: result = await self.mcp_client.call_tool(tool_name, arguments) output = convert_mcp_content_to_string(result.content) return output, not result.is_error except ToolError as e: # Catch MCP tool errors and return them to the agent error_msg = f"Tool error: {str(e)}" return error_msg, False return "MCP client not initialized", False # ============================================================================ # BUILT-IN TOOL HANDLERS # ============================================================================ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: """Create built-in tool specifications""" # in order of importance tools = [ # Research sub-agent (delegates to read-only tools in independent context) ToolSpec( name=RESEARCH_TOOL_SPEC["name"], description=RESEARCH_TOOL_SPEC["description"], parameters=RESEARCH_TOOL_SPEC["parameters"], handler=research_handler, ), # Documentation search tools ToolSpec( name=EXPLORE_HF_DOCS_TOOL_SPEC["name"], description=EXPLORE_HF_DOCS_TOOL_SPEC["description"], parameters=EXPLORE_HF_DOCS_TOOL_SPEC["parameters"], handler=explore_hf_docs_handler, ), ToolSpec( name=HF_DOCS_FETCH_TOOL_SPEC["name"], description=HF_DOCS_FETCH_TOOL_SPEC["description"], parameters=HF_DOCS_FETCH_TOOL_SPEC["parameters"], handler=hf_docs_fetch_handler, ), # Paper discovery and reading ToolSpec( name=HF_PAPERS_TOOL_SPEC["name"], description=HF_PAPERS_TOOL_SPEC["description"], parameters=HF_PAPERS_TOOL_SPEC["parameters"], handler=hf_papers_handler, ), # Dataset inspection tool (unified) ToolSpec( name=HF_INSPECT_DATASET_TOOL_SPEC["name"], description=HF_INSPECT_DATASET_TOOL_SPEC["description"], parameters=HF_INSPECT_DATASET_TOOL_SPEC["parameters"], handler=hf_inspect_dataset_handler, ), # Planning and job management tools ToolSpec( name=PLAN_TOOL_SPEC["name"], description=PLAN_TOOL_SPEC["description"], parameters=PLAN_TOOL_SPEC["parameters"], handler=plan_tool_handler, ), ToolSpec( name=HF_JOBS_TOOL_SPEC["name"], description=HF_JOBS_TOOL_SPEC["description"], parameters=HF_JOBS_TOOL_SPEC["parameters"], handler=hf_jobs_handler, ), # HF Repo management tools ToolSpec( name=HF_REPO_FILES_TOOL_SPEC["name"], description=HF_REPO_FILES_TOOL_SPEC["description"], parameters=HF_REPO_FILES_TOOL_SPEC["parameters"], handler=hf_repo_files_handler, ), ToolSpec( name=HF_REPO_GIT_TOOL_SPEC["name"], description=HF_REPO_GIT_TOOL_SPEC["description"], parameters=HF_REPO_GIT_TOOL_SPEC["parameters"], handler=hf_repo_git_handler, ), ToolSpec( name=GITHUB_FIND_EXAMPLES_TOOL_SPEC["name"], description=GITHUB_FIND_EXAMPLES_TOOL_SPEC["description"], parameters=GITHUB_FIND_EXAMPLES_TOOL_SPEC["parameters"], handler=github_find_examples_handler, ), ToolSpec( name=GITHUB_LIST_REPOS_TOOL_SPEC["name"], description=GITHUB_LIST_REPOS_TOOL_SPEC["description"], parameters=GITHUB_LIST_REPOS_TOOL_SPEC["parameters"], handler=github_list_repos_handler, ), ToolSpec( name=GITHUB_READ_FILE_TOOL_SPEC["name"], 
description=GITHUB_READ_FILE_TOOL_SPEC["description"], parameters=GITHUB_READ_FILE_TOOL_SPEC["parameters"], handler=github_read_file_handler, ), ] # Sandbox or local tools (highest priority) if local_mode: from agent.tools.local_tools import get_local_tools tools = get_local_tools() + tools else: tools = get_sandbox_tools() + tools tool_names = ", ".join([t.name for t in tools]) logger.info(f"Loaded {len(tools)} built-in tools: {tool_names}") return tools ================================================ FILE: agent/main.py ================================================ """ Interactive CLI chat with the agent Supports two modes: Interactive: python -m agent.main Headless: python -m agent.main "find me bird datasets" """ import argparse import asyncio import json import os import signal import sys import time from dataclasses import dataclass from pathlib import Path from typing import Any, Optional import litellm from prompt_toolkit import PromptSession from agent.config import load_config from agent.core.agent_loop import submission_loop from agent.core import model_switcher from agent.core.session import OpType from agent.core.tools import ToolRouter from agent.utils.reliability_checks import check_training_script_save_pattern from agent.utils.terminal_display import ( get_console, print_approval_header, print_approval_item, print_banner, print_compacted, print_error, print_help, print_init_done, print_interrupted, print_markdown, print_plan, print_tool_call, print_tool_log, print_tool_output, print_turn_complete, print_yolo_approve, ) litellm.drop_params = True # Suppress the "Give Feedback / Get Help" banner LiteLLM prints to stderr # on every error — users don't need it, and our friendly errors cover the case. litellm.suppress_debug_info = True def _safe_get_args(arguments: dict) -> dict: """Safely extract args dict from arguments, handling cases where LLM passes string.""" args = arguments.get("args", {}) # Sometimes LLM passes args as string instead of dict if isinstance(args, str): return {} return args if isinstance(args, dict) else {} def _get_hf_token() -> str | None: """Get HF token from environment, huggingface_hub API, or cached token file.""" token = os.environ.get("HF_TOKEN") if token: return token try: from huggingface_hub import HfApi api = HfApi() token = api.token if token: return token except Exception: pass # Fallback: read the cached token file directly token_path = Path.home() / ".cache" / "huggingface" / "token" if token_path.exists(): token = token_path.read_text().strip() if token: return token return None async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str: """Prompt user for HF token, validate it, save via huggingface_hub.login(). Loops until valid.""" from prompt_toolkit.formatted_text import HTML from huggingface_hub import HfApi, login print("\nA Hugging Face token is required.") print("Get one at: https://huggingface.co/settings/tokens\n") while True: try: token = await prompt_session.prompt_async( HTML("Paste your HF token: ") ) except (EOFError, KeyboardInterrupt): print("\nToken is required to continue.") continue token = token.strip() if not token: print("Token cannot be empty.") continue # Validate token against the API try: api = HfApi(token=token) user_info = api.whoami() username = user_info.get("name", "unknown") print(f"Token valid (user: {username})") except Exception: print("Invalid token. 
Please try again.") continue # Save for future sessions try: login(token=token, add_to_git_credential=False) print("Token saved to ~/.cache/huggingface/token") except Exception as e: print(f"Warning: could not persist token ({e}), using for this session only.") return token @dataclass class Operation: """Operation to be executed by the agent""" op_type: OpType data: Optional[dict[str, Any]] = None @dataclass class Submission: """Submission to the agent loop""" id: str operation: Operation def _create_rich_console(): """Get the shared rich Console.""" return get_console() class _ThinkingShimmer: """Animated shiny/shimmer thinking indicator — a bright gradient sweeps across the text.""" _BASE = (90, 90, 110) # dim base color _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold) _WIDTH = 5 # shimmer width in characters _FPS = 24 def __init__(self, console): self._console = console self._task = None self._running = False def start(self): if self._running: return self._running = True self._task = asyncio.ensure_future(self._animate()) def stop(self): if not self._running: return # no-op when never started (e.g. headless mode) self._running = False if self._task: self._task.cancel() self._task = None # Clear the shimmer line self._console.file.write("\r\033[K") self._console.file.flush() def _render_frame(self, text: str, offset: float) -> str: """Render one frame: a bright spot sweeps left-to-right across `text`.""" out = [] n = len(text) for i, ch in enumerate(text): # Distance from the shimmer center (wraps around) dist = abs(i - offset) wrap_dist = abs(i - offset + n + self._WIDTH) dist = min(dist, wrap_dist, abs(i - offset - n - self._WIDTH)) # Blend factor: 1.0 at center, 0.0 beyond _WIDTH t = max(0.0, 1.0 - dist / self._WIDTH) t = t * t * (3 - 2 * t) # smoothstep r = int(self._BASE[0] + (self._HIGHLIGHT[0] - self._BASE[0]) * t) g = int(self._BASE[1] + (self._HIGHLIGHT[1] - self._BASE[1]) * t) b = int(self._BASE[2] + (self._HIGHLIGHT[2] - self._BASE[2]) * t) out.append(f"\033[38;2;{r};{g};{b}m{ch}") out.append("\033[0m") return "".join(out) async def _animate(self): text = "Thinking..." n = len(text) speed = 0.45 # characters per frame pos = 0.0 try: while self._running: frame = self._render_frame(text, pos) self._console.file.write(f"\r {frame}") self._console.file.flush() pos = (pos + speed) % (n + self._WIDTH) await asyncio.sleep(1.0 / self._FPS) except asyncio.CancelledError: pass class _StreamBuffer: """Accumulates streamed tokens, renders markdown block-by-block as complete blocks appear. A "block" is everything up to a paragraph break (\\n\\n). 
Unclosed code fences (odd count of ```) hold back flushing until closed so a code block is always rendered as one unit.""" def __init__(self, console): self._console = console self._buffer = "" def add_chunk(self, text: str): self._buffer += text def _pop_block(self) -> str | None: """Extract the next complete block, or return None if nothing complete.""" if self._buffer.count("```") % 2 == 1: return None # inside an open code fence — wait for close idx = self._buffer.find("\n\n") if idx == -1: return None block = self._buffer[:idx] self._buffer = self._buffer[idx + 2:] return block async def flush_ready( self, cancel_event: "asyncio.Event | None" = None, instant: bool = False, ): """Render any complete blocks that have accumulated; leave the tail.""" while True: if cancel_event is not None and cancel_event.is_set(): return block = self._pop_block() if block is None: return if block.strip(): await print_markdown(block, cancel_event=cancel_event, instant=instant) async def finish( self, cancel_event: "asyncio.Event | None" = None, instant: bool = False, ): """Flush complete blocks, then render whatever incomplete tail remains.""" await self.flush_ready(cancel_event=cancel_event, instant=instant) if self._buffer.strip(): await print_markdown(self._buffer, cancel_event=cancel_event, instant=instant) self._buffer = "" def discard(self): self._buffer = "" async def event_listener( event_queue: asyncio.Queue, submission_queue: asyncio.Queue, turn_complete_event: asyncio.Event, ready_event: asyncio.Event, prompt_session: PromptSession, config=None, session_holder=None, ) -> None: """Background task that listens for events and displays them""" submission_id = [1000] last_tool_name = [None] console = _create_rich_console() shimmer = _ThinkingShimmer(console) stream_buf = _StreamBuffer(console) def _cancel_event(): """Return the session's cancellation Event so print_markdown can abort its typewriter loop mid-stream when Ctrl+C fires.""" s = session_holder[0] if session_holder else None return s._cancelled if s is not None else None while True: try: event = await event_queue.get() if event.event_type == "ready": tool_count = event.data.get("tool_count", 0) if event.data else 0 print_init_done(tool_count=tool_count) ready_event.set() elif event.event_type == "assistant_message": shimmer.stop() content = event.data.get("content", "") if event.data else "" if content: await print_markdown(content, cancel_event=_cancel_event()) elif event.event_type == "assistant_chunk": content = event.data.get("content", "") if event.data else "" if content: stream_buf.add_chunk(content) # Flush any complete markdown blocks progressively so the # user sees paragraphs appear as they're produced, not just # at the end of the whole response. 
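# A concrete trace (hypothetical chunks): receiving "Para1\n\nPar" then
# "a2\n\n" renders "Para1" on the first flush and "Para2" on the second;
# an odd number of ``` fences in the buffer holds everything back until
# the closing fence arrives, so code blocks render as one unit.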
shimmer.stop() await stream_buf.flush_ready(cancel_event=_cancel_event()) elif event.event_type == "assistant_stream_end": shimmer.stop() await stream_buf.finish(cancel_event=_cancel_event()) elif event.event_type == "tool_call": shimmer.stop() stream_buf.discard() tool_name = event.data.get("tool", "") if event.data else "" arguments = event.data.get("arguments", {}) if event.data else {} if tool_name: last_tool_name[0] = tool_name # Skip printing research tool_call — the tool_log handler shows it if tool_name != "research": args_str = json.dumps(arguments)[:80] print_tool_call(tool_name, args_str) elif event.event_type == "tool_output": output = event.data.get("output", "") if event.data else "" success = event.data.get("success", False) if event.data else False # Only show output for plan_tool — everything else is noise if last_tool_name[0] == "plan_tool" and output: print_tool_output(output, success, truncate=False) shimmer.start() elif event.event_type == "turn_complete": shimmer.stop() stream_buf.discard() print_turn_complete() print_plan() turn_complete_event.set() elif event.event_type == "interrupted": shimmer.stop() stream_buf.discard() print_interrupted() turn_complete_event.set() elif event.event_type == "undo_complete": console.print("[dim]Undone.[/dim]") turn_complete_event.set() elif event.event_type == "tool_log": tool = event.data.get("tool", "") if event.data else "" log = event.data.get("log", "") if event.data else "" if log: agent_id = event.data.get("agent_id", "") if event.data else "" label = event.data.get("label", "") if event.data else "" print_tool_log(tool, log, agent_id=agent_id, label=label) elif event.event_type == "tool_state_change": pass # visual noise — approval flow handles this elif event.event_type == "error": shimmer.stop() stream_buf.discard() error = event.data.get("error", "Unknown error") if event.data else "Unknown error" print_error(error) turn_complete_event.set() elif event.event_type == "shutdown": shimmer.stop() stream_buf.discard() break elif event.event_type == "processing": shimmer.start() elif event.event_type == "compacted": old_tokens = event.data.get("old_tokens", 0) if event.data else 0 new_tokens = event.data.get("new_tokens", 0) if event.data else 0 print_compacted(old_tokens, new_tokens) elif event.event_type == "approval_required": # Handle batch approval format tools_data = event.data.get("tools", []) if event.data else [] count = event.data.get("count", 0) if event.data else 0 # If yolo mode is active, auto-approve everything if config and config.yolo_mode: approvals = [ { "tool_call_id": t.get("tool_call_id", ""), "approved": True, "feedback": None, } for t in tools_data ] print_yolo_approve(count) submission_id[0] += 1 approval_submission = Submission( id=f"approval_{submission_id[0]}", operation=Operation( op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}, ), ) await submission_queue.put(approval_submission) continue print_approval_header(count) approvals = [] # Ask for approval for each tool for i, tool_info in enumerate(tools_data, 1): tool_name = tool_info.get("tool", "") arguments = tool_info.get("arguments", {}) tool_call_id = tool_info.get("tool_call_id", "") # Handle case where arguments might be a JSON string if isinstance(arguments, str): try: arguments = json.loads(arguments) except json.JSONDecodeError: print(f"Warning: Failed to parse arguments for {tool_name}") arguments = {} operation = arguments.get("operation", "") print_approval_item(i, count, tool_name, operation) # Handle different tool types if 
tool_name == "hf_jobs": # Check if this is Python mode (script) or Docker mode (command) script = arguments.get("script") command = arguments.get("command") if script: # Python mode dependencies = arguments.get("dependencies", []) python_version = arguments.get("python") script_args = arguments.get("script_args", []) # Show full script print(f"Script:\n{script}") if dependencies: print(f"Dependencies: {', '.join(dependencies)}") if python_version: print(f"Python version: {python_version}") if script_args: print(f"Script args: {' '.join(script_args)}") # Run reliability checks on the full script (not truncated) check_message = check_training_script_save_pattern(script) if check_message: print(check_message) elif command: # Docker mode image = arguments.get("image", "python:3.12") command_str = ( " ".join(command) if isinstance(command, list) else str(command) ) print(f"Docker image: {image}") print(f"Command: {command_str}") # Common parameters for jobs hardware_flavor = arguments.get("hardware_flavor", "cpu-basic") timeout = arguments.get("timeout", "30m") env = arguments.get("env", {}) schedule = arguments.get("schedule") print(f"Hardware: {hardware_flavor}") print(f"Timeout: {timeout}") if env: env_keys = ", ".join(env.keys()) print(f"Environment variables: {env_keys}") if schedule: print(f"Schedule: {schedule}") elif tool_name == "hf_private_repos": # Handle private repo operations args = _safe_get_args(arguments) if operation in ["create_repo", "upload_file"]: repo_id = args.get("repo_id", "") repo_type = args.get("repo_type", "dataset") # Build repo URL type_path = "" if repo_type == "model" else f"{repo_type}s" repo_url = ( f"https://huggingface.co/{type_path}/{repo_id}".replace( "//", "/" ) ) print(f"Repository: {repo_id}") print(f"Type: {repo_type}") print("Private: Yes") print(f"URL: {repo_url}") # Show file preview for upload_file operation if operation == "upload_file": path_in_repo = args.get("path_in_repo", "") file_content = args.get("file_content", "") print(f"File: {path_in_repo}") if isinstance(file_content, str): # Calculate metrics all_lines = file_content.split("\n") line_count = len(all_lines) size_bytes = len(file_content.encode("utf-8")) size_kb = size_bytes / 1024 size_mb = size_kb / 1024 print(f"Line count: {line_count}") if size_kb < 1024: print(f"Size: {size_kb:.2f} KB") else: print(f"Size: {size_mb:.2f} MB") # Show preview preview_lines = all_lines[:5] preview = "\n".join(preview_lines) print( f"Content preview (first 5 lines):\n{preview}" ) if len(all_lines) > 5: print("...") elif tool_name == "hf_repo_files": # Handle repo files operations (upload, delete) repo_id = arguments.get("repo_id", "") repo_type = arguments.get("repo_type", "model") revision = arguments.get("revision", "main") # Build repo URL if repo_type == "model": repo_url = f"https://huggingface.co/{repo_id}" else: repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}" print(f"Repository: {repo_id}") print(f"Type: {repo_type}") print(f"Branch: {revision}") print(f"URL: {repo_url}") if operation == "upload": path = arguments.get("path", "") content = arguments.get("content", "") create_pr = arguments.get("create_pr", False) print(f"File: {path}") if create_pr: print("Mode: Create PR") if isinstance(content, str): all_lines = content.split("\n") line_count = len(all_lines) size_bytes = len(content.encode("utf-8")) size_kb = size_bytes / 1024 print(f"Lines: {line_count}") if size_kb < 1024: print(f"Size: {size_kb:.2f} KB") else: print(f"Size: {size_kb / 1024:.2f} MB") # Show full content 
print(f"Content:\n{content}") elif operation == "delete": patterns = arguments.get("patterns", []) if isinstance(patterns, str): patterns = [patterns] print(f"Patterns to delete: {', '.join(patterns)}") elif tool_name == "hf_repo_git": # Handle git operations (branches, tags, PRs, repo management) repo_id = arguments.get("repo_id", "") repo_type = arguments.get("repo_type", "model") # Build repo URL if repo_type == "model": repo_url = f"https://huggingface.co/{repo_id}" else: repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}" print(f"Repository: {repo_id}") print(f"Type: {repo_type}") print(f"URL: {repo_url}") if operation == "delete_branch": branch = arguments.get("branch", "") print(f"Branch to delete: {branch}") elif operation == "delete_tag": tag = arguments.get("tag", "") print(f"Tag to delete: {tag}") elif operation == "merge_pr": pr_num = arguments.get("pr_num", "") print(f"PR to merge: #{pr_num}") elif operation == "create_repo": private = arguments.get("private", False) space_sdk = arguments.get("space_sdk") print(f"Private: {private}") if space_sdk: print(f"Space SDK: {space_sdk}") elif operation == "update_repo": private = arguments.get("private") gated = arguments.get("gated") if private is not None: print(f"Private: {private}") if gated is not None: print(f"Gated: {gated}") # Get user decision for this item. Ctrl+C / EOF here is # treated as "reject remaining" (matches Codex's modal # priority and Forgecode's approval-cancel path). Without # this, KeyboardInterrupt kills the event listener and # the main loop deadlocks waiting for turn_complete. try: response = await prompt_session.prompt_async( f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): " ) except (KeyboardInterrupt, EOFError): get_console().print("[dim]Approval cancelled — rejecting remaining items[/dim]") approvals.append( { "tool_call_id": tool_call_id, "approved": False, "feedback": "User cancelled approval", } ) for remaining in tools_data[i:]: approvals.append( { "tool_call_id": remaining.get("tool_call_id", ""), "approved": False, "feedback": None, } ) break response = response.strip().lower() # Handle yolo mode activation if response == "yolo": config.yolo_mode = True print( "YOLO MODE ACTIVATED - Auto-approving all future tool calls" ) # Auto-approve this item and all remaining approvals.append( { "tool_call_id": tool_call_id, "approved": True, "feedback": None, } ) for remaining in tools_data[i:]: approvals.append( { "tool_call_id": remaining.get("tool_call_id", ""), "approved": True, "feedback": None, } ) break approved = response in ["y", "yes"] feedback = None if approved or response in ["n", "no"] else response approvals.append( { "tool_call_id": tool_call_id, "approved": approved, "feedback": feedback, } ) # Submit batch approval submission_id[0] += 1 approval_submission = Submission( id=f"approval_{submission_id[0]}", operation=Operation( op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}, ), ) await submission_queue.put(approval_submission) console.print() # spacing after approval # Silently ignore other events except asyncio.CancelledError: break except Exception as e: print(f"Event listener error: {e}") async def get_user_input(prompt_session: PromptSession) -> str: """Get user input asynchronously""" from prompt_toolkit.formatted_text import HTML return await prompt_session.prompt_async(HTML("\n> ")) # ── Slash command helpers ──────────────────────────────────────────────── # Slash commands are defined in terminal_display async def _handle_slash_command( 
cmd: str, config, session_holder: list, submission_queue: asyncio.Queue, submission_id: list[int], ) -> Submission | None: """ Handle a slash command. Returns a Submission to enqueue, or None if the command was handled locally (caller should set turn_complete_event). Async because ``/model`` fires a probe ping to validate the model+effort combo before committing the switch. """ parts = cmd.strip().split(None, 1) command = parts[0].lower() arg = parts[1].strip() if len(parts) > 1 else "" if command == "/help": print_help() return None if command == "/undo": submission_id[0] += 1 return Submission( id=f"sub_{submission_id[0]}", operation=Operation(op_type=OpType.UNDO), ) if command == "/compact": submission_id[0] += 1 return Submission( id=f"sub_{submission_id[0]}", operation=Operation(op_type=OpType.COMPACT), ) if command == "/model": console = get_console() if not arg: model_switcher.print_model_listing(config, console) return None if not model_switcher.is_valid_model_id(arg): model_switcher.print_invalid_id(arg, console) return None normalized = arg.removeprefix("huggingface/") session = session_holder[0] if session_holder else None await model_switcher.probe_and_switch_model( normalized, config, session, console, _get_hf_token(), ) return None if command == "/yolo": config.yolo_mode = not config.yolo_mode state = "ON" if config.yolo_mode else "OFF" print(f"YOLO mode: {state}") return None if command == "/effort": console = get_console() valid = {"minimal", "low", "medium", "high", "xhigh", "max", "off"} session = session_holder[0] if session_holder else None if not arg: current = config.reasoning_effort or "off" console.print(f"[bold]Reasoning effort preference:[/bold] {current}") if session and session.model_effective_effort: console.print("[dim]Probed per model:[/dim]") for m, eff in session.model_effective_effort.items(): console.print(f" [dim]{m}: {eff or 'off'}[/dim]") console.print( "[dim]Set with '/effort minimal|low|medium|high|xhigh|max|off'. " "'max' and 'xhigh' are Anthropic-only; the cascade falls back " "to whatever the model actually accepts.[/dim]" ) return None level = arg.lower() if level not in valid: console.print(f"[bold red]Invalid level:[/bold red] {arg}") console.print(f"[dim]Expected one of: {', '.join(sorted(valid))}[/dim]") return None config.reasoning_effort = None if level == "off" else level # Drop the per-model probe cache — the new preference may resolve # differently. Next ``/model`` (or the retry safety net) reprobes. if session is not None: session.model_effective_effort.clear() console.print(f"[green]Reasoning effort: {level}[/green]") if session is not None: console.print( "[dim]run /model to re-probe, or send a message — " "the agent adjusts automatically if the new level isn't supported.[/dim]" ) return None if command == "/status": session = session_holder[0] if session_holder else None print(f"Model: {config.model_name}") print(f"Reasoning effort: {config.reasoning_effort or 'off'}") if session: print(f"Turns: {session.turn_count}") print(f"Context items: {len(session.context_manager.items)}") return None print(f"Unknown command: {command}. 
Type /help for available commands.") return None async def main(): """Interactive chat with the agent""" # Clear screen os.system("clear" if os.name != "nt" else "cls") # Create prompt session for input (needed early for token prompt) prompt_session = PromptSession() # HF token — required, prompt if missing hf_token = _get_hf_token() if not hf_token: hf_token = await _prompt_and_save_hf_token(prompt_session) # Resolve username for banner hf_user = None try: from huggingface_hub import HfApi hf_user = HfApi(token=hf_token).whoami().get("name") except Exception: pass print_banner(hf_user=hf_user) # Pre-warm the HF router catalog in the background so /model switches # don't block on a network fetch. from agent.core import hf_router_catalog asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm)) # Create queues for communication submission_queue = asyncio.Queue() event_queue = asyncio.Queue() # Events to signal agent state turn_complete_event = asyncio.Event() turn_complete_event.set() ready_event = asyncio.Event() # Start agent loop in background config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json" config = load_config(config_path) # Create tool router with local mode tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True) # Session holder for interrupt/model/status access session_holder = [None] agent_task = asyncio.create_task( submission_loop( submission_queue, event_queue, config=config, tool_router=tool_router, session_holder=session_holder, hf_token=hf_token, local_mode=True, stream=True, ) ) # Start event listener in background listener_task = asyncio.create_task( event_listener( event_queue, submission_queue, turn_complete_event, ready_event, prompt_session, config, session_holder=session_holder, ) ) await ready_event.wait() submission_id = [0] # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137 # (`QUIT_SHORTCUT_TIMEOUT = Duration::from_secs(1)`). Two Ctrl+C presses # within this window quit; a single press cancels the in-flight turn. CTRL_C_QUIT_WINDOW = 1.0 # Hint string matches codex-rs/tui/src/bottom_pane/footer.rs:746 # (`" again to quit"` prefixed with the key binding, rendered dim). CTRL_C_HINT = "[dim]ctrl + c again to quit[/dim]" interrupt_state = {"last": 0.0, "exit": False} loop = asyncio.get_running_loop() def _on_sigint() -> None: """SIGINT handler — fires while the agent is generating (terminal is in cooked mode between prompts). Mirrors Codex's `on_ctrl_c` in codex-rs/tui/src/chatwidget.rs: first press cancels active work and arms the quit hint; second press within the window quits.""" now = time.monotonic() session = session_holder[0] if now - interrupt_state["last"] < CTRL_C_QUIT_WINDOW: interrupt_state["exit"] = True if session: session.cancel() # Wake the main loop out of turn_complete_event.wait() turn_complete_event.set() return interrupt_state["last"] = now if session and not session.is_cancelled: session.cancel() get_console().print(f"\n{CTRL_C_HINT}") def _install_sigint() -> bool: try: loop.add_signal_handler(signal.SIGINT, _on_sigint) return True except (NotImplementedError, RuntimeError): return False # Windows or non-main thread # prompt_toolkit's prompt_async installs its own SIGINT handler and, on # exit, calls loop.remove_signal_handler(SIGINT) — which wipes ours too. # So we re-arm at the top of every loop iteration, right before the busy # wait. Without this, Ctrl+C during agent streaming after the first turn # falls through to the default handler and the terminal just echoes ^C. 
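# Timing sketch (window = 1.0s): ^C at t=0.0 cancels the turn and prints
# the hint; a second ^C at t=0.6 (inside the window) exits; at t=1.4 it
# merely re-arms and prints the hint again. The KeyboardInterrupt branch
# at the input prompt below applies the same window.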
sigint_available = _install_sigint() try: while True: if sigint_available: _install_sigint() try: await turn_complete_event.wait() except asyncio.CancelledError: break turn_complete_event.clear() if interrupt_state["exit"]: break # Get user input. prompt_toolkit puts the terminal in raw mode and # installs its own SIGINT handling; ^C arrives as \x03 and surfaces # as KeyboardInterrupt here. On return, prompt_toolkit removes the # loop's SIGINT handler — we re-arm at the top of the next iter. try: user_input = await get_user_input(prompt_session) except EOFError: break except KeyboardInterrupt: now = time.monotonic() if now - interrupt_state["last"] < CTRL_C_QUIT_WINDOW: break interrupt_state["last"] = now get_console().print(CTRL_C_HINT) turn_complete_event.set() continue # A successful read ends the double-press window — an unrelated # Ctrl+C during the next turn should start a fresh arming. interrupt_state["last"] = 0.0 # Check for exit commands if user_input.strip().lower() in ["exit", "quit", "/quit", "/exit"]: break # Skip empty input if not user_input.strip(): turn_complete_event.set() continue # Handle slash commands if user_input.strip().startswith("/"): sub = await _handle_slash_command( user_input.strip(), config, session_holder, submission_queue, submission_id ) if sub is None: # Command handled locally, loop back for input turn_complete_event.set() continue else: await submission_queue.put(sub) continue # Submit to agent submission_id[0] += 1 submission = Submission( id=f"sub_{submission_id[0]}", operation=Operation( op_type=OpType.USER_INPUT, data={"text": user_input} ), ) await submission_queue.put(submission) except KeyboardInterrupt: pass finally: if sigint_available: try: loop.remove_signal_handler(signal.SIGINT) except (NotImplementedError, RuntimeError): pass # Shutdown shutdown_submission = Submission( id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN) ) await submission_queue.put(shutdown_submission) # Wait for agent to finish (the listener must keep draining events # or the agent will block on event_queue.put) try: await asyncio.wait_for(agent_task, timeout=10.0) except asyncio.TimeoutError: agent_task.cancel() # Agent didn't shut down cleanly — close MCP explicitly await tool_router.__aexit__(None, None, None) # Now safe to cancel the listener (agent is done emitting events) listener_task.cancel() get_console().print("\n[dim]Bye.[/dim]\n") async def headless_main( prompt: str, model: str | None = None, max_iterations: int | None = None, stream: bool = True, ) -> None: """Run a single prompt headlessly and exit.""" import logging logging.basicConfig(level=logging.WARNING) hf_token = _get_hf_token() if not hf_token: print("ERROR: No HF token found. 
Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr) sys.exit(1) print(f"HF token loaded", file=sys.stderr) config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json" config = load_config(config_path) config.yolo_mode = True # Auto-approve everything in headless mode if model: config.model_name = model if max_iterations is not None: config.max_iterations = max_iterations print(f"Model: {config.model_name}", file=sys.stderr) print(f"Max iterations: {config.max_iterations}", file=sys.stderr) print(f"Prompt: {prompt}", file=sys.stderr) print("---", file=sys.stderr) submission_queue: asyncio.Queue = asyncio.Queue() event_queue: asyncio.Queue = asyncio.Queue() tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True) session_holder: list = [None] agent_task = asyncio.create_task( submission_loop( submission_queue, event_queue, config=config, tool_router=tool_router, session_holder=session_holder, hf_token=hf_token, local_mode=True, stream=stream, ) ) # Wait for ready while True: event = await event_queue.get() if event.event_type == "ready": break # Submit the prompt submission = Submission( id="sub_1", operation=Operation(op_type=OpType.USER_INPUT, data={"text": prompt}), ) await submission_queue.put(submission) # Process events until turn completes. Headless mode is for scripts / # log capture: no shimmer animation, no typewriter, no live-redrawing # research overlay. Output is plain, append-only text. console = _create_rich_console() stream_buf = _StreamBuffer(console) _hl_last_tool = [None] _hl_sub_id = [1] # Research sub-agent tool calls are buffered per agent_id and dumped as # a static block once each sub-agent finishes, instead of streaming via # the live redrawing SubAgentDisplayManager (which is TTY-only). _hl_research_buffers: dict[str, dict] = {} while True: event = await event_queue.get() if event.event_type == "assistant_chunk": content = event.data.get("content", "") if event.data else "" if content: stream_buf.add_chunk(content) await stream_buf.flush_ready(instant=True) elif event.event_type == "assistant_stream_end": await stream_buf.finish(instant=True) elif event.event_type == "assistant_message": content = event.data.get("content", "") if event.data else "" if content: await print_markdown(content, instant=True) elif event.event_type == "tool_call": stream_buf.discard() tool_name = event.data.get("tool", "") if event.data else "" arguments = event.data.get("arguments", {}) if event.data else {} if tool_name: _hl_last_tool[0] = tool_name if tool_name != "research": args_str = json.dumps(arguments)[:80] print_tool_call(tool_name, args_str) elif event.event_type == "tool_output": output = event.data.get("output", "") if event.data else "" success = event.data.get("success", False) if event.data else False if _hl_last_tool[0] == "plan_tool" and output: print_tool_output(output, success, truncate=False) elif event.event_type == "tool_log": tool = event.data.get("tool", "") if event.data else "" log = event.data.get("log", "") if event.data else "" if not log: pass elif tool == "research": # Headless mode: buffer research sub-agent activity per-agent, # then dump each as a static block on completion. The live # SubAgentDisplayManager uses terminal cursor tricks that are # unfit for non-TTY output, but parallel agents still need # distinct output so we key buffers by agent_id. 
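# Buffer lifecycle per agent_id (keyed on the sentinel strings emitted by
# the research tool): "Starting research sub-agent..." creates the buffer,
# ordinary log lines append to buf["calls"], "tokens:"/"tools:" stats are
# dropped, and "Research complete." pops the buffer and prints the label
# plus the buffered calls as one static block.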
agent_id = event.data.get("agent_id", "") if event.data else "" label = event.data.get("label", "") if event.data else "" aid = agent_id or "research" if log == "Starting research sub-agent...": _hl_research_buffers[aid] = { "label": label or "research", "calls": [], } elif log == "Research complete.": buf = _hl_research_buffers.pop(aid, None) if buf is not None: f = get_console().file f.write(f" \033[38;2;255;200;80m▸ {buf['label']}\033[0m\n") for call in buf["calls"]: f.write(f" \033[2m{call}\033[0m\n") f.flush() elif log.startswith("tokens:") or log.startswith("tools:"): pass # stats updates — only useful for the live display elif aid in _hl_research_buffers: _hl_research_buffers[aid]["calls"].append(log) else: # Orphan event (Start was missed) — fall back to raw print print_tool_log(tool, log, agent_id=agent_id, label=label) else: print_tool_log(tool, log) elif event.event_type == "approval_required": # Auto-approve everything in headless mode (safety net if yolo_mode # didn't prevent the approval event for some reason) tools_data = event.data.get("tools", []) if event.data else [] approvals = [ { "tool_call_id": t.get("tool_call_id", ""), "approved": True, "feedback": None, } for t in tools_data ] _hl_sub_id[0] += 1 await submission_queue.put(Submission( id=f"hl_approval_{_hl_sub_id[0]}", operation=Operation( op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}, ), )) elif event.event_type == "compacted": old_tokens = event.data.get("old_tokens", 0) if event.data else 0 new_tokens = event.data.get("new_tokens", 0) if event.data else 0 print_compacted(old_tokens, new_tokens) elif event.event_type == "error": stream_buf.discard() error = event.data.get("error", "Unknown error") if event.data else "Unknown error" print_error(error) break elif event.event_type in ("turn_complete", "interrupted"): stream_buf.discard() history_size = event.data.get("history_size", "?") if event.data else "?" 
print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr) break # Shutdown shutdown_submission = Submission( id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN) ) await submission_queue.put(shutdown_submission) try: await asyncio.wait_for(agent_task, timeout=10.0) except asyncio.TimeoutError: agent_task.cancel() await tool_router.__aexit__(None, None, None) def cli(): """Entry point for the ml-intern CLI command.""" import logging as _logging import warnings # Suppress aiohttp "Unclosed client session" noise during event loop teardown _logging.getLogger("asyncio").setLevel(_logging.CRITICAL) # Suppress litellm pydantic deprecation warnings warnings.filterwarnings("ignore", category=DeprecationWarning, module="litellm") # Suppress whoosh invalid escape sequence warnings (third-party, unfixed upstream) warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh") parser = argparse.ArgumentParser(description="Hugging Face Agent CLI") parser.add_argument("prompt", nargs="?", default=None, help="Run headlessly with this prompt") parser.add_argument("--model", "-m", default=None, help=f"Model to use (default: from config)") parser.add_argument("--max-iterations", type=int, default=None, help="Max LLM requests per turn (default: 50, use -1 for unlimited)") parser.add_argument("--no-stream", action="store_true", help="Disable token streaming (use non-streaming LLM calls)") args = parser.parse_args() try: if args.prompt: max_iter = args.max_iterations if max_iter is not None and max_iter < 0: max_iter = 10_000 # effectively unlimited asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream)) else: asyncio.run(main()) except KeyboardInterrupt: print("\n\nGoodbye!") if __name__ == "__main__": cli() ================================================ FILE: agent/prompts/system_prompt.yaml ================================================ system_prompt: | You are Hugging Face Agent, a skilled AI assistant for machine learning engineering. Hugging Face is a company that provides two main services : libraries to write deep learning tasks, and resources (models, datasets, compute) to execute them. You will aid users to do these tasks, interacting with the Hugging Face stack via {{ num_tools }}. # General behavior Your main goal is to achieve what the user asked. For this proactive in the quantity of actions taken. However, never make big decisions in place of the user. For example, confirm with user which models or datasets to use, or major training decisions. # Task Approach. **CRITICAL : Research first, Then Implement** For ANY implementation task (training, fine-tuning, inference, data processing, etc.), you should proceed in these three mandatory steps: 1. **FIRST**: Search HF documentation to find the correct approach. - Use `explore_hf_docs` to discover documentation structure for relevant libraries (e.g., "trl", "transformers", "diffusers"). - Use `fetch_hf_docs` to retrieve full content from the relevant pages you've found. - Use `search_hf_api_endpoints` to find API endpoints with usage examples. - Skip ONLY for simple factual questions (e.g., "What is LoRA?") 2. **THEN**: Formulate a plan based on research findings. Pass todos to the PlanTool. Update frequently to show when progress is made. This will also help you decompose hard tasks. 3. **FINALLY**: Implement using researched approaches - Search Hugging Face hub to find the exact user-specified model and dataset. 
If you can't find it and are thinking about changing the model / dataset, confirm explicitly with the user beforehand. - If the user has not provided the model or the dataset, suggest different options, and make the user choose before proceeding. - Use all available tools to complete the task. - Invoke multiple independent tools simultaneously for efficiency # Available Tools You have access to the following main categories of tools. For each, you are provided with typical use cases, but they can have many more. - Hugging Face Hub - Find models, datasets, and machine learning papers - Discover existing Spaces (mini-deployed AI models) - Access details about specific repositories - Note: models, datasets, and Spaces are all repositories - Documentation and API - Browse documentation across Hugging Face libraries (e.g., trl, diffusers, transformers, datasets) - Read full documentation pages - Search and inspect API endpoints - Planning - Use as a planning and to-do tool - Decompose complex tasks into manageable steps - Communicate plans and progress clearly with the user - Jobs - Run code as one-time executions on remote servers - Support both simple CPU tasks and intensive GPU workloads - Private Repos - Manage the user’s private repositories - Store and retrieve job outputs. This tool allows you to create repos and upload job results after their completion. - Fix or update Spaces - Reminder: repositories include models, datasets, Spaces, and generic repos - Spaces - Use deployed AI models - Perform tasks such as image generation, OCR, and text-to-speech # Additional instructions - Use up-to-date Python package versions. This is important. The default installations are the newest versions, so check documentation before relying on your outdated internal knowledge. - Always search official documentation before implementing any ML workflow; never assume methods, libraries, or approaches - Use Hugging Face documentation tools and search the Hub before building custom solutions - Verify dataset structures and API details explicitly; never assume column names or schemas - Base implementations on documented best practices, not general knowledge - Follow ML best practices: proper train/val/test splits, reproducibility, evaluation metrics, and suitable hardware - Treat Spaces and repos as permanent storage; job executions have no persistent files - Jobs require passing the full file contents; local and remote file systems are separate - HF_TOKEN is loaded from environment variables; never expose or log secrets - Include direct links when referencing models, datasets, or papers - Always do what the user tells you to. # Communication style - Be concise and direct. - Don't flatter the user. - Never use emojis or exclamation points. - If you are limited in a task, offer alternatives. - Don't thank the user when they provide results. - Explain what you're doing for non-trivial operations. - If the user asks something, answer. User questions take precedence over task completion. - Answer the user's question directly without elaboration unless they ask for detail. One-word answers are best when appropriate. # Examples User: Fine-tune a Llama-style model for instruction following on a custom dataset. Assistant: 1. Create a plan with plan_tool outlining data loading, model selection, training, and evaluation steps. 2. Use explore_hf_docs to locate documentation for transformers, trl, and peft. 3. Use fetch_hf_docs to read the relevant documentation more precisely. 4.
Use dataset_search to inspect available instruction datasets and confirm with the user. 5. Use model_search to find compatible base models and confirm the choice. 6. Launch training with hf_jobs using documented best practices and push the fine-tuned model and relevant artifacts to the Hub. User: My Space crashes on startup. Can you fix it? Assistant: 1. Create a plan with plan_tool to identify logs, runtime issues, and dependency updates. 2. Use hub_repo_details to inspect the Space repository and logs. 3. Use explore_hf_docs to find Space deployment and Gradio/Streamlit best practices. 4. Update files in the Space repo using hf_private_repos. 5. Restart and verify the Space. User: Find a good dataset for image captioning and summarize its structure. Assistant: 1. Create a plan with plan_tool for dataset discovery, inspection, and verification. 2. Use dataset_search with tags such as "image-captioning". 3. Use hub_repo_details to inspect candidate datasets. 4. Verify column names, splits, and licensing explicitly. 5. Report findings concisely and include direct links. User: Generate images using a fast text-to-image model. Assistant: 1. Create a plan with plan_tool to confirm style, resolution, and output format. 2. Use gr1_z_image_turbo_generate with the provided prompt. 3. Return generated images without additional commentary. User: Run inference with a specific text classification model on my text file. Assistant: 1. Create a plan with plan_tool for loading data, selecting model, and running inference. 2. Use model_search to locate the exact model and confirm with the user. 3. Use explore_hf_docs and fetch_hf_docs to find the correct inference API. 4. Execute the script with hf_jobs. User: Is there recent research on parameter-efficient fine-tuning? Assistant: 1. Create a plan with plan_tool to search, filter, and summarize relevant papers. 2. Use paper_search with semantic queries related to PEFT. 3. Identify relevant papers and verify publication details. 4. Summarize key findings briefly and include direct links. User: Build a small demo that does OCR on images. Assistant: 1. Create a plan with plan_tool to define input, OCR method, and demo output. 2. Use space_search to find existing OCR Spaces for reference. 3. Use explore_hf_docs to review OCR-related pipelines. 4. Implement using dynamic_space to execute OCR tasks. User: What models are trending right now for speech recognition? Assistant: 1. Create a plan with plan_tool to filter models by task and relevance. 2. Use model_search with task filters for speech recognition. 3. Sort by trending or downloads. 4. Report top results with short descriptions and links. ================================================ FILE: agent/prompts/system_prompt_v2.yaml ================================================ system_prompt: | You are Hugging Face Agent, a skilled AI assistant for machine learning engineering with deep expertise in the Hugging Face ecosystem. You help users accomplish ML tasks (training, fine-tuning, data processing, inference, evaluation) by interacting with Hugging Face services via {{ num_tools }} specialized tools. _Current Time: **{{ current_date }} {{ current_time }} ({{ current_timezone }})**_ {% if hf_user_info %}_AUTHENTICATED ON HF AS: **{{ hf_user_info }}**_{% endif %} # Core Mission & Behavior Your primary goal is to successfully complete what the user requested with ZERO ERRORS.
You are fully autonomous in executing tasks - research thoroughly, validate resources, choose optimal configurations, and proceed directly to implementation.

**Success Criteria for Long-Running Complex Tasks:**
- Research current documentation before implementing
- Validate all resources (models, datasets, formats)
- Set appropriate timeouts and hardware
- Handle async operations correctly
- Ensure result persistence
- Communicate progress clearly
- Handle errors gracefully with solutions

# ⚠️ MANDATORY Three-Phase Workflow
**FOR ANY ML IMPLEMENTATION TASK, YOU MUST FOLLOW THIS WORKFLOW:**

## PHASE 1: RESEARCH (Mandatory - Never Skip)
⚠️ **CRITICAL:** Your training data is outdated. NEVER implement ML tasks without researching current documentation AND working example code first.

**Use the `research` tool.** It spawns a sub-agent with its own context window that explores docs, reads example code, and returns a concise summary — keeping your context clean.

```python
# Example: User requests "Fine-tune a model for instruction following using SFT"
research({
    "task": "Research current TRL SFTTrainer: find working example scripts in the trl repo, read the SFT example implementation, check SFTConfig parameters in docs, and check trackio monitoring setup.",
    "context": "User wants to fine-tune a model for instruction following using SFT."
})
# Returns: key findings, code patterns, imports, config parameters, file references
```

**Be specific in your research task** — include library names, trainer types, dataset names, specific questions. The sub-agent knows how to use github_find_examples, github_read_file, explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, and hf_papers.

**You can also call research tools directly** (explore_hf_docs, github_read_file, etc.) for quick lookups that don't need a full research cycle.

**Skip research ONLY for:**
- Simple factual questions ("What is LoRA?", "What is DPO?")
- Status checks (`hf_jobs("ps")`, `hf_jobs("logs", job_id="xxx")`)
- Resource discovery (`model_search`, `dataset_search`, `paper_search`)
- Trivial operations that don't require implementation

## PHASE 2: PLAN & VALIDATE (Required for Multi-Step Tasks)
⚠️ **CRITICAL:** Break down complex tasks and validate resources BEFORE executing.

### Step 1: Create Execution Plan
Use `plan_tool` for any task with 3+ steps:

```python
plan_tool({
    "todos": [
        {"id": "1", "content": "Research TRL SFT documentation", "status": "completed"},
        {"id": "2", "content": "Find and verify base model", "status": "in_progress"},
        {"id": "3", "content": "Find dataset and validate columns and conversational format", "status": "pending"},
        {"id": "4", "content": "Create training script with Trackio", "status": "pending"},
        {"id": "5", "content": "Submit training job with correct config", "status": "pending"},
        {"id": "6", "content": "Provide monitoring URLs and expectations", "status": "pending"}
    ]
})
```

**Plan Requirements:**
- Exactly ONE task `in_progress` at a time
- Mark `completed` IMMEDIATELY after finishing (don't batch)
- Update plan frequently to show progress
- Only mark `completed` when fully done with no errors
- Keep `pending` if blocked - create new task to resolve blocker

### Step 2: Discover & Validate Resources

**For Training Tasks:**
1. ✅ **Find base model:**
```python
model_search({"query": "qwen3 4b instruct", "sort": "downloads", "limit": 5})
```
2. ✅ **Get model details:**
```python
hub_repo_details({"repo_ids": ["Qwen/Qwen3-4B-Instruct-2507"]})
# Verify: size, architecture, license, suitability
```
3.
✅ **Find training dataset:**
```python
dataset_search({"query": "instruct chat", "tags": ["conversational"], "limit": 5})
```
4. ✅ **Get dataset details AND VALIDATE FORMAT:**
```python
hub_repo_details({"repo_ids": ["HuggingFaceH4/ultrachat_200k"]})
# ⚠️ CRITICAL: Verify dataset columns and format (must be conversational) matches training method!
# - SFT: needs "messages", "text", or "prompt"/"completion"
# - DPO: needs "prompt", "chosen", "rejected"
# - GRPO: needs "prompt" only
```
5. ✅ **Select optimal resources:**
- Choose the most suitable model for the task (size, quality, performance balance) if the user has not specified a model
- Select an appropriate dataset with verified format compatibility if the user has not specified a dataset
- Determine optimal hardware based on model size and budget efficiency
- Proceed directly to implementation after validation

**Dataset Format Validation is CRITICAL:**
- Training will FAIL if the format doesn't match the method or is not conversational
- ALWAYS check with `hub_repo_details` before training
- Different training methods have different requirements
- Validate that the format matches the method before proceeding

**For Data Processing Tasks:**
1. ✅ Find dataset with `dataset_search`
2. ✅ Verify structure with `hub_repo_details`
3. ✅ Determine optimal processing approach based on requirements
4. ✅ Plan output format and destination

## PHASE 3: IMPLEMENT (Execute with Researched Approaches)

### For Training Tasks
⚠️ **TRAINING REQUIREMENTS CHECKLIST:**

**Before Submission:**
- [ ] Researched current TRL documentation
- [ ] Found and verified base model
- [ ] Found dataset and VALIDATED columns and conversational format matches method
- [ ] Selected optimal model + dataset + hardware configuration
- [ ] Created plan with plan_tool
- [ ] Researched Trackio monitoring setup

**Training Script MUST Include:**
- [ ] Imports from researched documentation (current APIs)
- [ ] Trackio initialization with project/run_name/config
- [ ] Model and tokenizer loading
- [ ] Dataset loading with verified columns and conversational format
- [ ] Training config with ALL critical settings:
  - `push_to_hub=True` ⚠️ MANDATORY
  - `hub_model_id="username/model-name"` ⚠️ MANDATORY
  - `report_to=["trackio"]` (for monitoring)
  - `output_dir="./output"`
  - `num_train_epochs`, `per_device_train_batch_size`, `learning_rate`
  - `logging_steps`, `save_steps`
  - `max_length` if needed (default 1024 usually fine)
- [ ] Trainer initialization with model, args, dataset, tokenizer
- [ ] `trainer.train()` call
- [ ] `trainer.push_to_hub()` at end ⚠️ MANDATORY
- [ ] `tracker.finish()` for Trackio

**Job Configuration MUST Include:**
- [ ] `operation`: "run" (for one-time) or "scheduled run" (for recurring)
- [ ] `script`: Training script with all above elements
- [ ] `dependencies`: ['transformers', 'trl', 'torch', 'datasets', 'trackio']
- [ ] `hardware_flavor`: Based on model size (see hf_jobs tool for detailed vCPU/RAM/GPU specs):
  - 1-3B models: `t4-small` (4vCPU/15GB/GPU 16GB) for demos or `a10g-small` (4vCPU/14GB/GPU 24GB) for production
  - 7-13B models: `a10g-large` (12vCPU/46GB/GPU 24GB)
  - 30B+ models: `a100-large` (12vCPU/142GB/GPU 80GB)
  - 70B+ models: `h100` (23vCPU/240GB/GPU 80GB) or `h100x8` for distributed
- [ ] `timeout`: ⚠️ CRITICAL - Set based on model/data size:
  - Small models (1-3B): "2h" to "4h"
  - Medium models (7-13B): "4h" to "8h"
  - Large models (30B+): "8h" to "24h"
  - **NEVER use default 30m for training!**
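Putting the checklist together, a minimal script sketch (model, dataset, and `hub_model_id` are illustrative placeholders — verify current TRL argument names via research before submitting):

```python
# Illustrative sketch only — names and values are placeholders, not a vetted recipe.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")

args = SFTConfig(
    output_dir="./output",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    logging_steps=10,
    save_steps=500,
    push_to_hub=True,                    # MANDATORY: job storage is ephemeral
    hub_model_id="username/model-name",  # MANDATORY
    report_to=["trackio"],               # monitoring
)

trainer = SFTTrainer(model="Qwen/Qwen3-4B-Instruct-2507", args=args, train_dataset=dataset)
trainer.train()
trainer.push_to_hub()                    # MANDATORY: persist the result
```

### For Data Processing Tasks
**Script Requirements:**
- Load dataset with `load_dataset`
-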
Process according to user requirements - Push results with `push_to_hub()` or upload to `hf_private_repos` **Job Configuration:** - Use `cpu-upgrade` or `cpu-performance` for most data tasks - Set timeout based on dataset size (1-4 hours typical) ### For Inference Tasks **Pattern:** 1. Research inference approach in docs 2. Find model with `model_search` + `hub_repo_details` 3. Create inference script with pipeline or generate 4. Submit with `hf_jobs` on appropriate hardware 5. Provide monitoring info ### For Evaluation Tasks **Pattern:** 1. Research evaluation framework (lighteval, lm-evaluation-harness) 2. Find model to evaluate 3. Create evaluation script 4. Submit job with appropriate hardware 5. Store results with `hf_private_repos` # Tool Usage Patterns for Reliability ## Research Use the `research` tool for any ML implementation research. It handles the full github_find_examples → github_read_file → explore_hf_docs → fetch_hf_docs chain in its own context and returns a summary. You can also call these tools directly for quick lookups. ## Hub Discovery Tools (MCP) **model_search / dataset_search / paper_search / hub_repo_details:** - Find models, datasets, papers by query - ⚠️ ALWAYS verify dataset format with hub_repo_details before training - hub_repo_details: check model size, architecture, dataset columns/splits **find_hf_api:** - Find REST API endpoints by keyword or tag - For API-only operations: streaming logs, org management, etc. ## Execution & Storage Tools **hf_jobs:** - Execute workloads on cloud infrastructure with detailed hardware specs (vCPU/RAM/GPU) - ⚠️ Set timeout >30m (default too short) - ⚠️ Include HF_TOKEN for Hub operations - ⚠️ Storage is EPHEMERAL - must push_to_hub **hf_private_repos:** - Store job outputs persistently in datasets with push_to_hub (jobs lose files after completion) - Upload logs, scripts, results that can't push_to_hub - Create private repos for sensitive data - Content-based: pass strings/bytes, not file paths - After upload: provide repo URL to user **plan_tool:** - Break down complex tasks (3+ steps) - Update frequently to show progress - Exactly ONE task in_progress at a time - Mark completed immediately after finishing ## Space Tools (MCP) **space_search:** - Find deployed Spaces (demos, applications) - Discover existing implementations **use_space:** - Give user access to a Space - Returns link for user (may not be visible to you) **dynamic_space:** - Execute tasks using Space functionality - Image generation, OCR, text-to-speech, etc. 
- Only works with MCP-enabled Spaces

# Ground Rules for Reliability

## Async Operations (Jobs, Long Tasks)
**✓ DO:**
- Poll logs automatically after submission to ensure the job is running and works as expected
- Include Trackio dashboard URL for training jobs
- Note that the user can check status later
- Explain what's happening in the background

**✗ DON'T:**
- Keep checking status repeatedly after the initial verification unless the user asks
- Assume the job will complete quickly

## Resource Selection
**✓ DO:**
- Research and evaluate 3-5 options for models/datasets
- Assess key details (size, format, popularity, suitability)
- Select the optimal option based on task requirements and efficiency
- ALWAYS validate that the dataset format matches the training method before proceeding
- Choose hardware that balances cost and performance

**✗ DON'T:**
- Skip research and validation steps
- Assume the most popular option is automatically best for the task
- Proceed with training without format validation
- Select unnecessarily expensive hardware without justification

## Documentation Usage
**✓ DO:**
- Use the `research` tool before implementing any ML task
- Base implementation on the research findings (code patterns, imports, config)

**✗ DON'T:**
- Implement based on internal knowledge without researching first
- Assume you know current API syntax
- Skip research for "simple" ML tasks

## Error Handling & Recovery
**When Errors Occur:**
1. ✅ Keep task in `in_progress` status (don't mark complete)
2. ✅ Create new todo for resolving the issue
3. ✅ Explain error clearly with technical details
4. ✅ Provide actionable solution based on error type
5. ✅ Check documentation if API/syntax error
6. ✅ Verify configuration if job fails
7. ✅ Implement fix and retry automatically with corrected approach

**Common Issues & Solutions:**

### Job Timeout Exceeded
**Symptom:** Job stops mid-execution, incomplete
**Cause:** Timeout too short for workload
**Solution:**
```python
# ✗ WRONG: Default timeout
{"timeout": "30m"}  # Too short for training!

# ✓ CORRECT: Appropriate timeout
{"timeout": "4h"}  # For 1-3B model training
{"timeout": "8h"}  # For 7-13B model training
```

### Model Not Pushed to Hub
**Symptom:** Training completes but model not on Hub
**Causes & Solutions:**
1. Missing `push_to_hub=True` in training config
2. Missing `hub_model_id` in training config
3. Missing `HF_TOKEN` in job env
4. Token lacks write permissions

**Solution:**
```python
# Training config:
training_args = SFTConfig(
    push_to_hub=True,                    # ← Must be True
    hub_model_id="username/model-name",  # ← Must be set
    # ...
)
```

### Dataset Format Mismatch
**Symptom:** Training fails with KeyError or format errors
**Cause:** Dataset format doesn't match training method
**Solution:**
1. Use `hub_repo_details` to inspect dataset structure
2. Verify format requirements:
   - SFT: needs "messages", "text", or "prompt"/"completion"
   - DPO: needs "prompt", "chosen", "rejected"
   - GRPO: needs "prompt" only
3. Preprocess dataset to correct format
4. Proceed with corrected configuration

### Out of Memory (OOM)
**Symptom:** Job crashes with CUDA OOM error
**Solutions (in order of preference):**
1. Increase `gradient_accumulation_steps` (compensates for the smaller batch)
2. Reduce `per_device_train_batch_size` (try 4 → 2 → 1)
3. Enable `gradient_checkpointing=True`
4. Reduce `max_length` (e.g., 1024 → 512)
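5. Upgrade to larger GPU (t4 → a10g → a100 → h100)

Steps 1-3 in one sketch (values are illustrative, not a definitive recipe): halve the per-device batch and double the accumulation so the effective batch size is unchanged:

```python
# Before the OOM: per_device_train_batch_size=8, gradient_accumulation_steps=2 (effective 16)
# After: same effective batch size of 16, roughly half the peak activation memory
training_args = SFTConfig(
    per_device_train_batch_size=4,   # reduced 8 → 4
    gradient_accumulation_steps=4,   # increased 2 → 4 (4 * 4 = 16, unchanged)
    gradient_checkpointing=True,     # trades extra compute for memory
    # ... other settings unchanged
)
```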
# Communication Style
- Be concise and direct
- Don't flatter the user
- Don't use emojis in regular communication (okay in status messages like "✅ Job submitted!")
- Don't use exclamation points in regular text
- If limited in a task, offer alternatives
- Don't thank the user when they provide information
- Explain what you're doing for non-trivial operations
- Answer user questions directly - questions take precedence over task completion
- One-word answers when appropriate for simple questions
- For complex tasks, provide structured breakdown

# ⚠️ CRITICAL: Task Completion Requirements
**You must FULLY satisfy the user's request before finishing your turn.** Do not stop prematurely.

**Before ending your turn, verify:**
1. ✅ Did I actually finish DOING what the user asked, not just explain it/partially do it?
2. ✅ Did I confirm the task succeeded (job submitted, file uploaded, etc.)?
3. ✅ If I encountered an error, did I fix it and retry?
4. ✅ For jobs/async tasks: Did I provide monitoring info and expected outcomes?

**Common mistakes to avoid:**
- ✗ Stopping after "I'll help you with X" without actually doing X
- ✗ Explaining what you WOULD do instead of DOING it
- ✗ Ending after a tool call fails without retrying or fixing
- ✗ Stopping mid-task because you described what happens next
- ✗ Not providing final summary with URLs/results after completing

**Correct behavior:**
- ✓ Continue calling tools until the task is actually complete
- ✓ After submitting a job, provide the job URL and monitoring links
- ✓ After an error, diagnose and fix it, then retry
- ✓ End with a clear summary of what was accomplished and any next steps

# Examples
User: Fine-tune Llama for instruction following on ultrachat dataset
Assistant: I'll fine-tune Llama for instruction following. Let me research current TRL SFT patterns and validate the dataset.
[Creates plan with plan_tool: Research, Find model, Validate dataset, Create script, Submit job]
[STEP 1: Research via sub-agent — keeps main context clean]
research({
  "task": "Research current TRL SFTTrainer: find working SFT example scripts in the trl repo, read the implementation, check SFTConfig parameters and imports. Also check trackio monitoring setup.",
  "context": "User wants to SFT fine-tune Llama on ultrachat dataset."
})
# Returns: key imports, SFTConfig params, working code patterns, trackio setup
[STEP 2: Discover and validate resources]
model_search({"query": "llama instruct", "sort": "downloads"})
hub_repo_details({"repo_ids": ["meta-llama/Llama-3.2-1B", "HuggingFaceH4/ultrachat_200k"]})
# Validates: model exists, dataset has "messages" column ✓ SFT-compatible
[STEP 3: Create and submit training job]
[Creates script based on research findings — correct imports, SFTConfig, dataset handling, trackio, push_to_hub]
[Submits training job with hf_jobs: hardware=t4-small, timeout=4h, env=HF_TOKEN]

User: My Space crashes on startup
Assistant: ✓ I'll help debug your Space. Let me inspect the repository and logs.
[Creates plan: Inspect repo, Identify errors, Research solutions, Fix issues]
[Uses hub_repo_details to get Space details and logs]
I can see the Space is failing with a Python import error. It's trying to import 'gradio' but the package isn't in requirements.txt.
[Explores documentation: explore_hf_docs("gradio"), find_hf_api(query="space logs") for streaming logs]
Based on the Gradio docs, I need to add gradio to requirements.txt and ensure the app.py file is correctly structured.
[Fixes using hf_private_repos: uploads corrected requirements.txt and app.py] [Checks if the Space is running and works as expected] User: Process the squad dataset - filter where context length > 100 chars and push to my Hub Assistant: ✓ I'll process the SQuAD dataset with your filter and push the results. [Creates plan: Find dataset, Research processing, Create script, Submit job, Monitor progress] [Discovers: dataset_search for squad, hub_repo_details to verify structure] Found the SQuAD dataset. It has 'context' and 'question' columns. I'll filter rows where len(context) > 100 characters. [Researches: explore_hf_docs("datasets"), fetch_hf_docs for processing/filtering] [Submits processing job with hf_jobs and makes sure to push the results to the Hub] # Additional Instructions - **Always use current information:** Use the `research` tool before implementing ML tasks; internal knowledge may be outdated - **Example code first:** The research sub-agent finds and reads working examples — real code shows current APIs and patterns - **Search before building:** Use Hub search tools, GitHub code search, and documentation before creating custom solutions - **Verify explicitly:** Never assume dataset schemas, column names, or API details; always check with hub_repo_details - **Base on documented practices:** Implement using researched approaches from documentation, not general knowledge - **Follow ML best practices:** Proper splits, reproducibility, evaluation metrics, suitable hardware - **Respect storage boundaries:** Spaces and repos are permanent; job filesystems are ephemeral - **Content-based operations:** For hf_private_repos, pass file contents not paths; local and remote filesystems are separate - **Secure secrets:** HF_TOKEN automatically available via env; never expose or log tokens - **Include links:** Provide direct URLs when referencing models, datasets, papers, jobs, repos - **Execute user requests:** Always do what the user asks you to do - **Parallel tool execution:** Call multiple independent tools simultaneously for efficiency when possible # Token Count & Context Management {{ num_tools }} tools are available. Tool descriptions are comprehensive to ensure reliable behavior for complex, long-running ML tasks. Prioritize: 1. Research current documentation before implementing 2. Validate resources before expensive operations 3. Handle async operations correctly 4. Ensure result persistence 5. Communicate progress and expectations clearly This verbose guidance optimizes for ZERO ERRORS in production ML workflows over token efficiency. ================================================ FILE: agent/prompts/system_prompt_v3.yaml ================================================ system_prompt: | You are Hugging Face Agent, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face ecosystem. Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation. # Your knowledge of HF libraries is outdated You do not know current APIs for TRL, Transformers, PEFT, Trackio, or other HF libraries. Your internal knowledge WILL produce wrong imports, wrong argument names, and wrong trainer configurations. Before writing any ML implementation code, start from the literature. 
The parallel research sub-agents can crawl papers, read their methodology sections, trace citation graphs, and extract the exact datasets and training recipes that produced published results. This is your primary advantage — use it.

Your default workflow for any ML task:
1. Find the landmark paper(s) for the task or domain
2. Crawl their citation graphs to find recent downstream work
3. Read methodology sections (not abstracts) of the most promising papers — especially recent ones with strong results, many citations, and publication in high-impact conferences
4. Extract the recipe: what dataset, what training method, what hyperparameters produced those results
5. Validate and use those datasets for training

```
research({"task": "Literature crawl for [task]. Start from [paper/topic]. Crawl citation graph for recent downstream papers. Read their methodology sections (3, 4, 5) — extract the exact datasets, training methods, and hyperparameters that produced their best results. Attribute every finding to a specific result (e.g. 'Dataset X + method Y → 85.3% on benchmark Z'). Also find working code examples using current TRL/Transformers APIs.", "context": "User wants to [goal]. We need the best training recipe backed by published results."})
```

The sub-agent knows how to use github_find_examples, github_read_file, explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, and hf_papers (with citation_graph, read_paper, snippet_search, find_datasets). Be specific in your task description — name anchor papers or arxiv IDs when you have them.

You can also call research tools directly (explore_hf_docs, github_read_file, etc.) for quick lookups. Skip research only for trivial non-code operations.

# Mistakes you WILL make without research

HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first.

WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs.

WRONG DATASET FORMAT: You will assume column names without checking. Training fails with KeyError. Fix: call hf_inspect_dataset or hub_repo_details and verify columns match the training method (see the sketch after this list).

DEFAULT TIMEOUT KILLS JOBS: You will leave timeout at the default 30m for training jobs. Training takes hours. The job gets killed and all progress is lost. Fix: set timeout based on model size (minimum 2h for any training).

LOST MODELS: You will forget push_to_hub=True and hub_model_id in training config. Job storage is ephemeral — the filesystem is deleted when the job ends. Without push_to_hub, the trained model is permanently lost.

BATCH FAILURES: You will submit all ablation/batch jobs at once without testing that one works first. All will fail for the same bug. Fix: submit ONE job first, verify it completes successfully, then submit the rest.

SILENT DATASET SUBSTITUTION: When a requested dataset fails to load, you will silently switch to a different one without telling the user. Fix: if the requested dataset isn't available, tell the user and ask what to do.

HARDCODED UNAVAILABLE PACKAGES: You will forget to install necessary packages like 'flash-attn' for flash_attention_2 or other packages that aren't automatically installed in the job environment. Fix: install necessary packages before running the job.
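SCOPE-CHANGING FIXES: Avoid at all costs! When you hit an error (especially OOM), you will try "creative" workarounds that change what the user asked for and/or change the training task itself — switching full SFT to LoRA on OOM, reducing max_length (silently truncates training data and changes what the model learns), disabling monitoring instead of fixing it. Do not do this. Fix errors with minimal changes that preserve the user's original request and are grounded in research and examples. If the original approach genuinely cannot work, explain why and ask the user for input before changing methods, sequence length, training approach, or any other part of the task.

A minimal sketch of the WRONG DATASET FORMAT check (the helper name is hypothetical; the column sets mirror the requirements listed in the next section):

```python
# Hypothetical pre-training column check — not part of the toolset,
# just the shape of the validation to perform before submitting a job.
ACCEPTED_COLUMNS = {
    "sft": [{"messages"}, {"text"}, {"prompt", "completion"}],
    "dpo": [{"prompt", "chosen", "rejected"}],
    "grpo": [{"prompt"}],
}

def format_matches(columns: set[str], method: str) -> bool:
    """True if the columns satisfy at least one accepted schema for the method."""
    return any(required <= columns for required in ACCEPTED_COLUMNS[method])

# Columns come from hf_inspect_dataset / hub_repo_details output:
assert format_matches({"messages"}, "sft")
assert not format_matches({"question", "answer"}, "dpo")
```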
# When writing ML code

Required sequence before any training/fine-tuning/inference script:
1. Use the `research` tool to find working examples, read docs, and get current API patterns
2. Validate dataset: hf_inspect_dataset or hub_repo_details to confirm column names and format
3. Validate model: hub_repo_details to confirm the model exists, correct architecture/size/tokenizer

Training logging: always set disable_tqdm=True, logging_strategy="steps", and logging_first_step=True in your TrainingArguments/SFTConfig so loss values are printed as plain text lines you can grep, not hidden inside tqdm progress bars.

Dataset format requirements by training method:
SFT: "messages", "text", or "prompt"/"completion"
DPO: "prompt", "chosen", "rejected"
GRPO: "prompt"

# Data audit

Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it. Use hf_inspect_dataset to check: schema/columns, number of rows per split, value distributions for key columns, sample rows. Surface anything notable: class imbalance, missing values, unexpected formats, outliers, duplicate rows, etc. Looking at data is the best way to boost the performance of any ML model, and it reduces the likelihood of failed jobs later.

# When submitting a training job

Before calling hf_jobs, output a pre-flight check:
- Reference implementation: [which example you based this on]
- Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
- push_to_hub=True and hub_model_id set
- timeout: [value] (based on: [model size] on [hardware])
- Trackio monitoring included and working

If you cannot fill in all items, stop and complete the missing steps first.

For batch/ablation jobs: submit ONE job first. Check logs to confirm it starts training successfully. Only then submit the remaining jobs. Never submit all at once.

Hardware sizing:
1-3B params: a10g-largex2
7-13B params: a100-large
30B+ params: l40sx4 or a100x4
70B+ params: a100x8
Note: a10g-small and a10g-large have the SAME 24GB GPU memory. The difference is CPU/RAM only.

# Sandbox-first development

For non-trivial scripts, develop and test in a sandbox before launching via hf_jobs:
sandbox_create → install deps → write script → test with small run → fix errors → launch via hf_jobs at scale

Use a GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths.

# When a task has 3+ steps

Use plan_tool to track progress. One task in_progress at a time. Mark completed immediately after finishing. Update frequently to show the user what you're doing.

# Error recovery

When something fails:
- Diagnose the actual error. Read the full error message and logs.
- Do not retry the exact same thing. Identify what needs to change.
- If an API/import error: check documentation for the correct API.
- If an OOM error: (1) reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally to keep the effective batch size identical, (2) enable gradient_checkpointing=True, (3) upgrade to a larger GPU (a10gx4 → a100 → a100x4 → a100x8). Do NOT switch training methods (e.g. SFT → LoRA) or reduce max_length — those change what the user gets. If OOM happens in a sandbox, create a new sandbox with larger GPU hardware.
- Never change the user's requested approach (training method, dataset, model, sequence length) without explicit approval.
- If a tool call fails repeatedly for the same reason: stop and try a different approach.
- Never silently substitute resources (datasets, models) — tell the user if something isn't available.

# Task completion

Before ending your turn, verify:
- Did you actually DO what the user asked, not just explain what you would do?
- If something failed: did you diagnose and fix it, or at minimum explain what went wrong and ask for user input?
- For training jobs: did you include a working Trackio dashboard URL?

Do not stop after describing what you plan to do. Continue calling tools until the task is verifiably done. Do not mark plan tasks as completed if they failed or are only partially done.

# Autonomous / headless mode

When running autonomously (no human in the loop), you MUST follow these rules:

NEVER respond with only text. Every response MUST include at least one tool call. If you have nothing to do, check the plan, verify outputs, or plan ahead. A text-only response ends the agent loop permanently — there is no human to re-prompt you.

NEVER STOP WORKING. Do NOT decide you are "done" while time remains. The human is not watching — they expect you to use the ENTIRE time budget productively. Do NOT ask "should I continue?" or "is this a good stopping point?" — there is nobody to answer. You run until the timer expires or you are manually killed.

Your workflow is a loop, not a checklist. Once you have a working result, KEEP ITERATING:

LOOP UNTIL TIME RUNS OUT:
1. Research the approach (read docs, find examples, check current APIs)
2. Implement the solution (write code, set up training)
3. Train and evaluate
4. Save the model to the required output location / push it to Hugging Face Hub
5. Improve: tune hyperparameters, try different data, adjust the training recipe, try a different approach entirely
6. Go to step 1

HYPERPARAMETER TUNING: Do not tune hyperparameters by hand one-at-a-time. Write a script that launches a sweep over a grid of values (learning rate, epochs, batch size, etc.) and evaluates each run automatically. One well-designed sweep script beats ten manual experiments.

If you run out of ideas: go back to the literature. Crawl citation graphs deeper — find papers you haven't read yet, read their methodology sections, extract new datasets or training tricks. Look for papers that cite your current approach and improve on it. Try combining recipes from different papers. Re-read the task prompt for angles you missed. Re-read the training logs for clues. There is always a paper you haven't read yet, and it probably has a better dataset.

Check the remaining time periodically with the timer command specified in the task prompt. Budget your time: reserve at least 10 minutes at the end for final evaluation and model saving.

The task is NOT done until:
- The required output exists (e.g. final model, metrics reached, dataset updated, etc.)
- You have evaluated the model and confirmed it works

# Communication
- Be concise and direct.
No filler, no restating what the user said.
- One-word answers when appropriate for simple questions.
- Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
- For errors: state what went wrong, why, and what you're doing to fix it.
- Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.

# Tool usage
- Execute multiple independent tool calls in parallel when possible.
- HF_TOKEN is automatically available in job secrets — no need to pass it explicitly.
- For training monitoring: include Trackio in the script and provide the dashboard URL.
- For private/gated datasets: HF_TOKEN is needed — it's auto-loaded into job secrets.

================================================
FILE: agent/tools/__init__.py
================================================
"""
Hugging Face tools for the agent
"""

from agent.tools.dataset_tools import (
    HF_INSPECT_DATASET_TOOL_SPEC,
    hf_inspect_dataset_handler,
)
from agent.tools.github_find_examples import (
    GITHUB_FIND_EXAMPLES_TOOL_SPEC,
    github_find_examples_handler,
)
from agent.tools.github_list_repos import (
    GITHUB_LIST_REPOS_TOOL_SPEC,
    github_list_repos_handler,
)
from agent.tools.github_read_file import (
    GITHUB_READ_FILE_TOOL_SPEC,
    github_read_file_handler,
)
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
from agent.tools.types import ToolResult

__all__ = [
    "ToolResult",
    "HF_JOBS_TOOL_SPEC",
    "hf_jobs_handler",
    "HfJobsTool",
    "GITHUB_FIND_EXAMPLES_TOOL_SPEC",
    "github_find_examples_handler",
    "GITHUB_LIST_REPOS_TOOL_SPEC",
    "github_list_repos_handler",
    "GITHUB_READ_FILE_TOOL_SPEC",
    "github_read_file_handler",
    "HF_INSPECT_DATASET_TOOL_SPEC",
    "hf_inspect_dataset_handler",
]

================================================
FILE: agent/tools/dataset_tools.py
================================================
"""
Dataset Inspection Tool - Comprehensive dataset analysis in one call

Combines /is-valid, /splits, /info, /first-rows, and /parquet endpoints
to provide everything needed for ML tasks in a single tool call.
"""

import asyncio
from typing import Any, TypedDict

import httpx

from agent.tools.types import ToolResult

BASE_URL = "https://datasets-server.huggingface.co"

# Truncation limit for long sample values in the output
MAX_SAMPLE_VALUE_LEN = 150


class SplitConfig(TypedDict):
    """Typed representation of a dataset config and its splits."""

    name: str
    splits: list[str]


def _get_headers(token: str | None = None) -> dict:
    """Get auth headers for private/gated datasets"""
    if token:
        return {"Authorization": f"Bearer {token}"}
    return {}


async def inspect_dataset(
    dataset: str,
    config: str | None = None,
    split: str | None = None,
    sample_rows: int = 3,
    hf_token: str | None = None,
) -> ToolResult:
    """
    Get comprehensive dataset info in one call.
    All API calls made in parallel for speed.
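    Example (hypothetical call, for illustration only):
        result = await inspect_dataset("stanfordnlp/imdb", sample_rows=2)
        print(result["formatted"])  # markdown summary: status, schema, samples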
""" headers = _get_headers(hf_token) output_parts = [] errors = [] async with httpx.AsyncClient(timeout=15, headers=headers) as client: # Phase 1: Parallel calls for structure info (no dependencies) is_valid_task = client.get(f"{BASE_URL}/is-valid", params={"dataset": dataset}) splits_task = client.get(f"{BASE_URL}/splits", params={"dataset": dataset}) parquet_task = client.get(f"{BASE_URL}/parquet", params={"dataset": dataset}) results = await asyncio.gather( is_valid_task, splits_task, parquet_task, return_exceptions=True, ) # Process is-valid if not isinstance(results[0], Exception): try: output_parts.append(_format_status(results[0].json())) except Exception as e: errors.append(f"is-valid: {e}") # Process splits and auto-detect config/split configs = [] if not isinstance(results[1], Exception): try: splits_data = results[1].json() configs = _extract_configs(splits_data) if not config: config = configs[0]["name"] if configs else "default" if not split: split = configs[0]["splits"][0] if configs else "train" output_parts.append(_format_structure(configs)) except Exception as e: errors.append(f"splits: {e}") if not config: config = "default" if not split: split = "train" # Process parquet (will be added at the end) parquet_section = None if not isinstance(results[2], Exception): try: parquet_section = _format_parquet_files(results[2].json()) except Exception: pass # Silently skip if no parquet # Phase 2: Parallel calls for content (depend on config/split) info_task = client.get( f"{BASE_URL}/info", params={"dataset": dataset, "config": config} ) rows_task = client.get( f"{BASE_URL}/first-rows", params={"dataset": dataset, "config": config, "split": split}, timeout=30, ) content_results = await asyncio.gather( info_task, rows_task, return_exceptions=True, ) # Process info (schema) if not isinstance(content_results[0], Exception): try: output_parts.append(_format_schema(content_results[0].json(), config)) except Exception as e: errors.append(f"info: {e}") # Process sample rows if not isinstance(content_results[1], Exception): try: output_parts.append( _format_samples( content_results[1].json(), config, split, sample_rows ) ) except Exception as e: errors.append(f"rows: {e}") # Add parquet section at the end if available if parquet_section: output_parts.append(parquet_section) # Combine output formatted = f"# {dataset}\n\n" + "\n\n".join(output_parts) if errors: formatted += f"\n\n**Warnings:** {'; '.join(errors)}" return { "formatted": formatted, "totalResults": 1, "resultsShared": 1, "isError": len(output_parts) == 0, } def _format_status(data: dict) -> str: """Format /is-valid response as status line""" available = [ k for k in ["viewer", "preview", "search", "filter", "statistics"] if data.get(k) ] if available: return f"## Status\n✓ Valid ({', '.join(available)})" return "## Status\n✗ Dataset may have issues" def _extract_configs(splits_data: dict) -> list[SplitConfig]: """Group splits by config""" configs: dict[str, SplitConfig] = {} for s in splits_data.get("splits", []): cfg = s.get("config", "default") if cfg not in configs: configs[cfg] = {"name": cfg, "splits": []} configs[cfg]["splits"].append(s.get("split")) return list(configs.values()) def _format_structure(configs: list[SplitConfig], max_rows: int = 10) -> str: """Format configs and splits as a markdown table.""" lines = [ "## Structure (configs & splits)", "| Config | Split |", "|--------|-------|", ] total_splits = sum(len(cfg["splits"]) for cfg in configs) added_rows = 0 for cfg in configs: for split_name in cfg["splits"]: 
            if added_rows >= max_rows:
                break
            lines.append(f"| {cfg['name']} | {split_name} |")
            added_rows += 1
        if added_rows >= max_rows:
            break
    if total_splits > added_rows:
        # Overflow marker row: two columns, matching the table header above
        lines.append(
            f"| ... | (_showing {added_rows} of {total_splits} config/split rows_) |"
        )
    return "\n".join(lines)


def _format_schema(info: dict, config: str) -> str:
    """Extract features and format as table"""
    features = info.get("dataset_info", {}).get("features", {})
    lines = [f"## Schema ({config})", "| Column | Type |", "|--------|------|"]
    for col_name, col_info in features.items():
        col_type = _get_type_str(col_info)
        lines.append(f"| {col_name} | {col_type} |")
    return "\n".join(lines)


def _get_type_str(col_info: dict) -> str:
    """Convert feature info to readable type string"""
    dtype = col_info.get("dtype") or col_info.get("_type", "unknown")
    if col_info.get("_type") == "ClassLabel":
        names = col_info.get("names", [])
        if names and len(names) <= 5:
            return f"ClassLabel ({', '.join(f'{n}={i}' for i, n in enumerate(names))})"
        return f"ClassLabel ({len(names)} classes)"
    return str(dtype)


def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str:
    """Format sample rows, truncate long values"""
    rows = rows_data.get("rows", [])[:limit]
    lines = [f"## Sample Rows ({config}/{split})"]
    messages_col_data = None
    for i, row_wrapper in enumerate(rows, 1):
        row = row_wrapper.get("row", {})
        lines.append(f"**Row {i}:**")
        for key, val in row.items():
            # Check for messages column and capture first one for format analysis
            if key.lower() == "messages" and messages_col_data is None:
                messages_col_data = val
            val_str = str(val)
            if len(val_str) > MAX_SAMPLE_VALUE_LEN:
                val_str = val_str[:MAX_SAMPLE_VALUE_LEN] + "..."
            lines.append(f"- {key}: {val_str}")
    # If we found a messages column, add format analysis
    if messages_col_data is not None:
        messages_format = _format_messages_structure(messages_col_data)
        if messages_format:
            lines.append("")
            lines.append(messages_format)
    return "\n".join(lines)


def _format_messages_structure(messages_data: Any) -> str | None:
    """
    Analyze and format the structure of a messages column.
    Common in chat/instruction datasets.
""" import json # Parse if string if isinstance(messages_data, str): try: messages_data = json.loads(messages_data) except json.JSONDecodeError: return None if not isinstance(messages_data, list) or not messages_data: return None lines = ["## Messages Column Format"] # Analyze message structure roles_seen = set() has_tool_calls = False has_tool_results = False message_keys = set() for msg in messages_data: if not isinstance(msg, dict): continue message_keys.update(msg.keys()) role = msg.get("role", "") if role: roles_seen.add(role) if "tool_calls" in msg or "function_call" in msg: has_tool_calls = True if role in ("tool", "function") or msg.get("tool_call_id"): has_tool_results = True # Format the analysis lines.append( f"**Roles:** {', '.join(sorted(roles_seen)) if roles_seen else 'unknown'}" ) # Show common message keys with presence indicators common_keys = [ "role", "content", "tool_calls", "tool_call_id", "name", "function_call", ] key_status = [] for key in common_keys: if key in message_keys: key_status.append(f"{key} ✓") else: key_status.append(f"{key} ✗") lines.append(f"**Message keys:** {', '.join(key_status)}") if has_tool_calls: lines.append("**Tool calls:** ✓ Present") if has_tool_results: lines.append("**Tool results:** ✓ Present") # Show example message structure # Priority: 1) message with tool_calls, 2) first assistant message, 3) first non-system message example = None fallback = None for msg in messages_data: if not isinstance(msg, dict): continue role = msg.get("role", "") # Check for actual tool_calls/function_call values (not None) if msg.get("tool_calls") or msg.get("function_call"): example = msg break if role == "assistant" and example is None: example = msg elif role != "system" and fallback is None: fallback = msg if example is None: example = fallback if example: lines.append("") lines.append("**Example message structure:**") # Build a copy with truncated content but keep all keys example_clean = {} for key, val in example.items(): if key == "content" and isinstance(val, str) and len(val) > 100: example_clean[key] = val[:100] + "..." else: example_clean[key] = val lines.append("```json") lines.append(json.dumps(example_clean, indent=2, ensure_ascii=False)) lines.append("```") return "\n".join(lines) def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None: """Format parquet file info, return None if no files.""" files = data.get("parquet_files", []) if not files: return None # Group by config/split groups: dict[str, dict] = {} for f in files: key = f"{f.get('config', 'default')}/{f.get('split', 'train')}" if key not in groups: groups[key] = {"count": 0, "size": 0} size = f.get("size") or 0 if not isinstance(size, (int, float)): size = 0 groups[key]["count"] += 1 groups[key]["size"] += int(size) lines = ["## Files (Parquet)"] items = list(groups.items()) total_groups = len(items) shown = 0 for key, info in items[:max_rows]: size_mb = info["size"] / (1024 * 1024) lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)") shown += 1 if total_groups > shown: lines.append(f"- ... 
(_showing {shown} of {total_groups} parquet groups_)")
    return "\n".join(lines)


# Tool specification
HF_INSPECT_DATASET_TOOL_SPEC = {
    "name": "hf_inspect_dataset",
    "description": (
        "Inspect a HF dataset in one call: status, configs/splits, schema, sample rows, parquet info.\n\n"
        "REQUIRED before any training job to verify dataset format matches training method:\n"
        "  SFT: needs 'messages', 'text', or 'prompt'/'completion'\n"
        "  DPO: needs 'prompt', 'chosen', 'rejected'\n"
        "  GRPO: needs 'prompt'\n"
        "Conversational training datasets have to be in ChatML format to be compatible with HF libraries.\n"
        "Training will fail with KeyError if columns don't match.\n\n"
        "Also use to get example datapoints, understand column names, data types, and available splits before writing any data loading code. "
        "Supports private/gated datasets when HF_TOKEN is set."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "dataset": {
                "type": "string",
                "description": "Dataset ID in 'org/name' format (e.g., 'stanfordnlp/imdb')",
            },
            "config": {
                "type": "string",
                "description": "Config/subset name. Auto-detected if not specified.",
            },
            "split": {
                "type": "string",
                "description": "Split for sample rows. Auto-detected if not specified.",
            },
            "sample_rows": {
                "type": "integer",
                "description": "Number of sample rows to show (default: 3, max: 10)",
                "default": 3,
            },
        },
        "required": ["dataset"],
    },
}


async def hf_inspect_dataset_handler(arguments: dict[str, Any], session=None) -> tuple[str, bool]:
    """Handler for agent tool router"""
    try:
        hf_token = session.hf_token if session else None
        result = await inspect_dataset(
            dataset=arguments["dataset"],
            config=arguments.get("config"),
            split=arguments.get("split"),
            sample_rows=min(arguments.get("sample_rows", 3), 10),
            hf_token=hf_token,
        )
        return result["formatted"], not result.get("isError", False)
    except Exception as e:
        return f"Error inspecting dataset: {str(e)}", False


================================================
FILE: agent/tools/docs_tools.py
================================================
"""
Documentation search tools for exploring HuggingFace and Gradio documentation.
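Typical flow (hypothetical arguments, for illustration only; a session with an
HF token is required for non-Gradio endpoints):
    await explore_hf_docs_handler({"endpoint": "transformers", "query": "trainer"}, session)
    await hf_docs_fetch_handler({"url": "https://huggingface.co/docs/transformers/trainer"}, session)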
""" import asyncio import json from typing import Any import httpx from bs4 import BeautifulSoup from whoosh.analysis import StemmingAnalyzer from whoosh.fields import ID, TEXT, Schema from whoosh.filedb.filestore import RamStorage from whoosh.qparser import MultifieldParser, OrGroup # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- DEFAULT_MAX_RESULTS = 20 MAX_RESULTS_CAP = 50 GRADIO_LLMS_TXT_URL = "https://gradio.app/llms.txt" GRADIO_SEARCH_URL = "https://playground-worker.pages.dev/api/prompt" COMPOSITE_ENDPOINTS: dict[str, list[str]] = { "optimum": [ "optimum", "optimum-habana", "optimum-neuron", "optimum-intel", "optimum-executorch", "optimum-tpu", ], "courses": [ "llm-course", "robotics-course", "mcp-course", "smol-course", "agents-course", "deep-rl-course", "computer-vision-course", "audio-course", "ml-games-course", "diffusion-course", "ml-for-3d-course", "cookbook", ], } # --------------------------------------------------------------------------- # Caches # --------------------------------------------------------------------------- _docs_cache: dict[str, list[dict[str, str]]] = {} _index_cache: dict[str, tuple[Any, MultifieldParser]] = {} _cache_lock = asyncio.Lock() _openapi_cache: dict[str, Any] | None = None _openapi_index_cache: tuple[Any, MultifieldParser, list[dict[str, Any]]] | None = None # --------------------------------------------------------------------------- # Gradio Documentation # --------------------------------------------------------------------------- async def _fetch_gradio_docs(query: str | None = None) -> str: """ Fetch Gradio documentation. Without query: Get full documentation from llms.txt With query: Run embedding search on guides/demos for relevant content """ async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: if not query: resp = await client.get(GRADIO_LLMS_TXT_URL) resp.raise_for_status() return resp.text resp = await client.post( GRADIO_SEARCH_URL, headers={ "Content-Type": "application/json", "Origin": "https://gradio-docs-mcp.up.railway.app", }, json={ "prompt_to_embed": query, "SYSTEM_PROMPT": "$INSERT_GUIDES_DOCS_DEMOS", "FALLBACK_PROMPT": "No results found", }, ) resp.raise_for_status() return resp.json().get("SYS_PROMPT", "No results found") # --------------------------------------------------------------------------- # HF Documentation - Fetching # --------------------------------------------------------------------------- async def _fetch_endpoint_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]: """Fetch all docs for an endpoint by parsing sidebar and fetching each page.""" url = f"https://huggingface.co/docs/{endpoint}" headers = {"Authorization": f"Bearer {hf_token}"} async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: resp = await client.get(url, headers=headers) resp.raise_for_status() soup = BeautifulSoup(resp.text, "html.parser") sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x) if not sidebar: raise ValueError(f"Could not find navigation sidebar for '{endpoint}'") nav_items = [] for link in sidebar.find_all("a", href=True): href = link["href"] page_url = f"https://huggingface.co{href}" if href.startswith("/") else href nav_items.append({"title": link.get_text(strip=True), "url": page_url}) if not nav_items: raise ValueError(f"No navigation links found for '{endpoint}'") async def fetch_page(item: dict[str, str]) -> dict[str, str]: 
md_url = f"{item['url']}.md" try: r = await client.get(md_url, headers=headers) r.raise_for_status() content = r.text.strip() glimpse = content[:200] + "..." if len(content) > 200 else content except Exception as e: content, glimpse = "", f"[Could not fetch: {str(e)[:50]}]" return { "title": item["title"], "url": item["url"], "md_url": md_url, "glimpse": glimpse, "content": content, "section": endpoint, } return list(await asyncio.gather(*[fetch_page(item) for item in nav_items])) async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]: """Get docs for endpoint with caching. Expands composite endpoints.""" async with _cache_lock: if endpoint in _docs_cache: return _docs_cache[endpoint] sub_endpoints = COMPOSITE_ENDPOINTS.get(endpoint, [endpoint]) all_docs: list[dict[str, str]] = [] for sub in sub_endpoints: async with _cache_lock: if sub in _docs_cache: all_docs.extend(_docs_cache[sub]) continue docs = await _fetch_endpoint_docs(hf_token, sub) async with _cache_lock: _docs_cache[sub] = docs all_docs.extend(docs) async with _cache_lock: _docs_cache[endpoint] = all_docs return all_docs # --------------------------------------------------------------------------- # HF Documentation - Search # --------------------------------------------------------------------------- async def _build_search_index( endpoint: str, docs: list[dict[str, str]] ) -> tuple[Any, MultifieldParser]: """Build or retrieve cached Whoosh search index.""" async with _cache_lock: if endpoint in _index_cache: return _index_cache[endpoint] analyzer = StemmingAnalyzer() schema = Schema( title=TEXT(stored=True, analyzer=analyzer), url=ID(stored=True, unique=True), md_url=ID(stored=True), section=ID(stored=True), glimpse=TEXT(stored=True, analyzer=analyzer), content=TEXT(stored=False, analyzer=analyzer), ) storage = RamStorage() index = storage.create_index(schema) writer = index.writer() for doc in docs: writer.add_document( title=doc.get("title", ""), url=doc.get("url", ""), md_url=doc.get("md_url", ""), section=doc.get("section", endpoint), glimpse=doc.get("glimpse", ""), content=doc.get("content", ""), ) writer.commit() parser = MultifieldParser( ["title", "content"], schema=schema, fieldboosts={"title": 2.0, "content": 1.0}, group=OrGroup, ) async with _cache_lock: _index_cache[endpoint] = (index, parser) return index, parser async def _search_docs( endpoint: str, docs: list[dict[str, str]], query: str, limit: int ) -> tuple[list[dict[str, Any]], str | None]: """Search docs using Whoosh. Returns (results, fallback_message).""" index, parser = await _build_search_index(endpoint, docs) try: query_obj = parser.parse(query) except Exception: return [], "Query contained unsupported syntax; showing default ordering." with index.searcher() as searcher: results = searcher.search(query_obj, limit=limit) matches = [ { "title": hit["title"], "url": hit["url"], "md_url": hit.get("md_url", ""), "section": hit.get("section", endpoint), "glimpse": hit["glimpse"], "score": round(hit.score, 2), } for hit in results ] if not matches: return [], "No strong matches found; showing default ordering." 
return matches, None # --------------------------------------------------------------------------- # HF Documentation - Formatting # --------------------------------------------------------------------------- def _format_results( endpoint: str, items: list[dict[str, Any]], total: int, query: str | None = None, note: str | None = None, ) -> str: """Format search results as readable text.""" base_url = f"https://huggingface.co/docs/{endpoint}" out = f"Documentation structure for: {base_url}\n\n" if query: out += f"Query: '{query}' → showing {len(items)} result(s) out of {total} pages" if note: out += f" ({note})" out += "\n\n" else: out += f"Found {len(items)} page(s) (total available: {total}).\n" if note: out += f"({note})\n" out += "\n" for i, item in enumerate(items, 1): out += f"{i}. **{item['title']}**\n" out += f" URL: {item['url']}\n" out += f" Section: {item.get('section', endpoint)}\n" if query and "score" in item: out += f" Relevance score: {item['score']:.2f}\n" out += f" Glimpse: {item['glimpse']}\n\n" return out # --------------------------------------------------------------------------- # Handlers # --------------------------------------------------------------------------- async def explore_hf_docs_handler( arguments: dict[str, Any], session=None ) -> tuple[str, bool]: """Explore documentation structure with optional search query.""" endpoint = arguments.get("endpoint", "").lstrip("/") query = arguments.get("query") max_results = arguments.get("max_results") if not endpoint: return "Error: No endpoint provided", False # Gradio uses its own API if endpoint.lower() == "gradio": try: clean_query = ( query.strip() if isinstance(query, str) and query.strip() else None ) content = await _fetch_gradio_docs(clean_query) header = "# Gradio Documentation\n\n" if clean_query: header += f"Query: '{clean_query}'\n\n" header += "Source: https://gradio.app/docs\n\n---\n\n" return header + content, True except httpx.HTTPStatusError as e: return f"HTTP error fetching Gradio docs: {e.response.status_code}", False except httpx.RequestError as e: return f"Request error fetching Gradio docs: {str(e)}", False except Exception as e: return f"Error fetching Gradio docs: {str(e)}", False # HF docs hf_token = session.hf_token if session else None if not hf_token: return "Error: No HF token available (not logged in)", False try: max_results_int = int(max_results) if max_results is not None else None except (TypeError, ValueError): return "Error: max_results must be an integer", False if max_results_int is not None and max_results_int <= 0: return "Error: max_results must be greater than zero", False try: docs = await _get_docs(hf_token, endpoint) total = len(docs) # Determine limit if max_results_int is None: limit = DEFAULT_MAX_RESULTS limit_note = f"Showing top {DEFAULT_MAX_RESULTS} results (set max_results to adjust)." elif max_results_int > MAX_RESULTS_CAP: limit = MAX_RESULTS_CAP limit_note = f"Requested {max_results_int} but showing top {MAX_RESULTS_CAP} (maximum)." 
else: limit = max_results_int limit_note = None # Search or paginate clean_query = ( query.strip() if isinstance(query, str) and query.strip() else None ) fallback_msg = None if clean_query: results, fallback_msg = await _search_docs( endpoint, docs, clean_query, limit ) if not results: results = docs[:limit] else: results = docs[:limit] # Combine notes notes = [] if fallback_msg: notes.append(fallback_msg) if limit_note: notes.append(limit_note) note = "; ".join(notes) if notes else None return _format_results(endpoint, results, total, clean_query, note), True except httpx.HTTPStatusError as e: return f"HTTP error: {e.response.status_code} - {e.response.text[:200]}", False except httpx.RequestError as e: return f"Request error: {str(e)}", False except ValueError as e: return f"Error: {str(e)}", False except Exception as e: return f"Unexpected error: {str(e)}", False async def hf_docs_fetch_handler( arguments: dict[str, Any], session=None ) -> tuple[str, bool]: """Fetch full markdown content of a documentation page.""" url = arguments.get("url", "") if not url: return "Error: No URL provided", False hf_token = session.hf_token if session else None if not hf_token: return "Error: No HF token available (not logged in)", False if not url.endswith(".md"): url = f"{url}.md" try: async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: resp = await client.get( url, headers={"Authorization": f"Bearer {hf_token}"} ) resp.raise_for_status() return f"Documentation from: {url}\n\n{resp.text}", True except httpx.HTTPStatusError as e: return ( f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}", False, ) except httpx.RequestError as e: return f"Request error fetching {url}: {str(e)}", False except Exception as e: return f"Error fetching documentation: {str(e)}", False # --------------------------------------------------------------------------- # OpenAPI Search # --------------------------------------------------------------------------- async def _fetch_openapi_spec() -> dict[str, Any]: """Fetch and cache HuggingFace OpenAPI specification.""" global _openapi_cache if _openapi_cache is not None: return _openapi_cache async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: resp = await client.get("https://huggingface.co/.well-known/openapi.json") resp.raise_for_status() _openapi_cache = resp.json() return _openapi_cache def _extract_all_tags(spec: dict[str, Any]) -> list[str]: """Extract all unique tags from OpenAPI spec.""" tags = set() for tag_obj in spec.get("tags", []): if "name" in tag_obj: tags.add(tag_obj["name"]) for path_item in spec.get("paths", {}).values(): for method, op in path_item.items(): if method in ["get", "post", "put", "delete", "patch", "head", "options"]: for tag in op.get("tags", []): tags.add(tag) return sorted(tags) def _extract_all_endpoints(spec: dict[str, Any]) -> list[dict[str, Any]]: """Extract all endpoints from OpenAPI spec.""" servers = spec.get("servers", []) base_url = ( servers[0].get("url", "https://huggingface.co") if servers else "https://huggingface.co" ) endpoints = [] for path, path_item in spec.get("paths", {}).items(): for method, op in path_item.items(): if method not in [ "get", "post", "put", "delete", "patch", "head", "options", ]: continue endpoints.append( { "path": path, "method": method.upper(), "operationId": op.get("operationId", ""), "summary": op.get("summary", ""), "description": op.get("description", ""), "tags": " ".join(op.get("tags", [])), "parameters": 
op.get("parameters", []), "request_body": op.get("requestBody", {}), "responses": op.get("responses", {}), "base_url": base_url, } ) return endpoints async def _build_openapi_index() -> tuple[Any, MultifieldParser, list[dict[str, Any]]]: """Build or retrieve cached Whoosh index for OpenAPI endpoints.""" global _openapi_index_cache async with _cache_lock: if _openapi_index_cache is not None: return _openapi_index_cache spec = await _fetch_openapi_spec() endpoints = _extract_all_endpoints(spec) analyzer = StemmingAnalyzer() schema = Schema( path=ID(stored=True, unique=True), method=ID(stored=True), operationId=TEXT(stored=True, analyzer=analyzer), summary=TEXT(stored=True, analyzer=analyzer), description=TEXT(stored=True, analyzer=analyzer), tags=TEXT(stored=True, analyzer=analyzer), param_names=TEXT(stored=False, analyzer=analyzer), ) storage = RamStorage() index = storage.create_index(schema) writer = index.writer() for ep in endpoints: param_names = " ".join(p.get("name", "") for p in ep.get("parameters", [])) writer.add_document( path=ep["path"], method=ep["method"], operationId=ep.get("operationId", ""), summary=ep.get("summary", ""), description=ep.get("description", ""), tags=ep.get("tags", ""), param_names=param_names, ) writer.commit() parser = MultifieldParser( ["summary", "description", "operationId", "tags", "param_names"], schema=schema, fieldboosts={ "summary": 3.0, "operationId": 2.0, "description": 1.0, "tags": 1.5, }, group=OrGroup, ) async with _cache_lock: _openapi_index_cache = (index, parser, endpoints) return index, parser, endpoints async def _search_openapi( query: str, tag: str | None, limit: int = 20 ) -> tuple[list[dict[str, Any]], str | None]: """Search OpenAPI endpoints using Whoosh. Returns (results, fallback_message).""" index, parser, endpoints = await _build_openapi_index() try: query_obj = parser.parse(query) except Exception: return [], "Query contained unsupported syntax." with index.searcher() as searcher: results = searcher.search( query_obj, limit=limit * 2 ) # Get extra for tag filtering matches = [] for hit in results: # Find full endpoint data ep = next( ( e for e in endpoints if e["path"] == hit["path"] and e["method"] == hit["method"] ), None, ) if ep is None: continue # Filter by tag if provided if tag and tag not in ep.get("tags", ""): continue matches.append({**ep, "score": round(hit.score, 2)}) if len(matches) >= limit: break return matches, None if matches else "No matches found for query." 
def _generate_curl_example(endpoint: dict[str, Any]) -> str:
    """Generate curl command example for an endpoint."""
    method = endpoint["method"]
    path = endpoint["path"]
    base_url = endpoint["base_url"]

    # Build URL with path parameters
    full_path = path
    for param in endpoint.get("parameters", []):
        if param.get("in") == "path" and param.get("required"):
            name = param["name"]
            example = param.get(
                "example", param.get("schema", {}).get("example", f"<{name}>")
            )
            full_path = full_path.replace(f"{{{name}}}", str(example))

    # Append the first required query parameter inside the quoted URL.
    # (Appending it after the closing quote would leave ? and = unquoted
    # in the generated shell command.)
    query_suffix = ""
    query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"]
    if query_params and query_params[0].get("required"):
        param = query_params[0]
        example = param.get("example", param.get("schema", {}).get("example", "value"))
        query_suffix = f"?{param['name']}={example}"

    curl = f"curl -X {method} \\\n '{base_url}{full_path}{query_suffix}'"
    curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'"

    # Add request body
    if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"):
        content = endpoint["request_body"].get("content", {})
        if "application/json" in content:
            curl += " \\\n -H 'Content-Type: application/json'"
            schema = content["application/json"].get("schema", {})
            example = schema.get("example", "{}")
            if isinstance(example, dict):
                example = json.dumps(example, indent=2)
            curl += f" \\\n -d '{example}'"

    return curl


def _format_parameters(parameters: list[dict[str, Any]]) -> str:
    """Format parameter information from OpenAPI spec."""
    if not parameters:
        return ""

    path_params = [p for p in parameters if p.get("in") == "path"]
    query_params = [p for p in parameters if p.get("in") == "query"]
    header_params = [p for p in parameters if p.get("in") == "header"]

    output = []
    for label, params in [
        ("Path Parameters", path_params),
        ("Query Parameters", query_params),
        ("Header Parameters", header_params),
    ]:
        if not params:
            continue
        if output:
            output.append("")
        output.append(f"**{label}:**")
        for p in params:
            name = p.get("name", "")
            required = " (required)" if p.get("required") else " (optional)"
            desc = p.get("description", "")
            ptype = p.get("schema", {}).get("type", "string")
            example = p.get("example") or p.get("schema", {}).get("example", "")
            output.append(f"- `{name}` ({ptype}){required}: {desc}")
            if example:
                output.append(f"  Example: `{example}`")

    return "\n".join(output)


def _format_response_info(responses: dict[str, Any]) -> str:
    """Format response information from OpenAPI spec."""
    if not responses:
        return "No response information available"

    output = []
    for status, resp_obj in list(responses.items())[:3]:
        desc = resp_obj.get("description", "")
        output.append(f"- **{status}**: {desc}")
        content = resp_obj.get("content", {})
        if "application/json" in content:
            schema = content["application/json"].get("schema", {})
            if "type" in schema:
                output.append(f"  Returns: {schema.get('type', 'object')}")

    return "\n".join(output)


def _format_openapi_results(
    results: list[dict[str, Any]],
    tag: str | None = None,
    query: str | None = None,
    note: str | None = None,
) -> str:
    """Format OpenAPI search results with curl examples."""
    if not results:
        if query and tag:
            return f"No API endpoints found matching '{query}' in tag '{tag}'"
        elif query:
            return f"No API endpoints found matching '{query}'"
        elif tag:
            return f"No API endpoints found with tag '{tag}'"
        return "No API endpoints found"

    # Build header
    if query and tag:
        out = f"# API Endpoints matching '{query}' (tag: `{tag}`)\n\n"
    elif query:
        out = f"# API Endpoints matching '{query}'\n\n"
    elif tag:
        out = f"# API Endpoints for tag: `{tag}`\n\n"
else: out = "# API Endpoints\n\n" out += f"Found {len(results)} endpoint(s)" if note: out += f" ({note})" out += "\n\n---\n\n" for i, ep in enumerate(results, 1): out += f"## {i}. {ep['method']} {ep['path']}\n\n" if query and "score" in ep: out += f"**Relevance:** {ep['score']:.2f}\n\n" if ep.get("summary"): out += f"**Summary:** {ep['summary']}\n\n" if ep.get("description"): desc = ep["description"][:300] if len(ep["description"]) > 300: desc += "..." out += f"**Description:** {desc}\n\n" if ep.get("tags"): out += f"**Tags:** {ep['tags']}\n\n" params_info = _format_parameters(ep.get("parameters", [])) if params_info: out += params_info + "\n\n" out += "**Usage:**\n```bash\n" out += _generate_curl_example(ep) out += "\n```\n\n" out += "**Returns:**\n" out += _format_response_info(ep["responses"]) out += "\n\n---\n\n" return out async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]: """Search HuggingFace OpenAPI specification by query and/or tag.""" tag = arguments.get("tag", "").strip() or None query = arguments.get("query", "").strip() or None if not tag and not query: return ( "Error: Provide either 'query' (keyword search) or 'tag' (category filter), or both.", False, ) try: note = None # If query provided, try Whoosh search first if query: results, search_note = await _search_openapi(query, tag, limit=20) # If Whoosh found results, return them if results: return _format_openapi_results( results, tag=tag, query=query, note=search_note ), True # Whoosh found nothing - fall back to tag-based if tag provided if tag: note = f"No matches for '{query}'; showing all endpoints in tag '{tag}'" else: # No tag to fall back to return _format_openapi_results([], query=query), True # Tag-based search (either as fallback or primary) if tag: _, _, endpoints = await _build_openapi_index() results = [ep for ep in endpoints if tag in ep.get("tags", "")] return _format_openapi_results( results, tag=tag, query=None, note=note ), True return "Error: No results found", False except httpx.HTTPStatusError as e: return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False except httpx.RequestError as e: return f"Request error: {str(e)}", False except Exception as e: return f"Error searching OpenAPI spec: {str(e)}", False async def _get_api_search_tool_spec() -> dict[str, Any]: """Generate OpenAPI tool spec with tags populated at runtime.""" spec = await _fetch_openapi_spec() tags = _extract_all_tags(spec) return { "name": "find_hf_api", "description": ( "Find HuggingFace Hub REST API endpoints to make HTTP requests. Returns curl examples with authentication. " "⚠️ USE THIS TOOL when you need to call the HF Hub API directly - for operations like: " "uploading/downloading files, managing repos, listing models/datasets, getting user info, " "managing webhooks, collections, discussions, or any Hub interaction not covered by other tools. " "**Use cases:** (1) 'Stream Space logs' → query='space logs', " "(2) 'Get Space metrics/Zero-GPU usage' → query='space metrics', " "(3) 'List organization members' → query='organization members', " "(4) 'Generate repo access token' → query='jwt token', " "(5) 'Check repo security scan' → query='security scan'. " "**Search modes:** Use 'query' for keyword search, 'tag' to browse a category, or both. " "If query finds no results, falls back to showing all endpoints in the tag. " "**Output:** Full endpoint details with method, path, parameters, curl command, and response schema." 
), "parameters": { "type": "object", "properties": { "query": { "type": "string", "description": ( "Keyword search across endpoint summaries, descriptions, and operation IDs. " "Examples: 'upload file', 'create repository', 'list user models', 'delete branch', " "'webhook', 'collection', 'discussion comments'. Supports stemming (upload/uploading both work)." ), }, "tag": { "type": "string", "enum": tags, "description": ( "Filter by API category. Use alone to browse all endpoints in a category, " "or combine with 'query' to search within a category." ), }, }, "required": [], }, } # --------------------------------------------------------------------------- # Tool Specifications # --------------------------------------------------------------------------- DOC_ENDPOINTS = [ "hub", "transformers", "diffusers", "datasets", "gradio", "trackio", "smolagents", "huggingface_hub", "huggingface.js", "transformers.js", "inference-providers", "inference-endpoints", "peft", "accelerate", "optimum", "tokenizers", "courses", "evaluate", "tasks", "dataset-viewer", "trl", "simulate", "sagemaker", "timm", "safetensors", "tgi", "setfit", "lerobot", "autotrain", "tei", "bitsandbytes", "sentence_transformers", "chat-ui", "leaderboards", "lighteval", "argilla", "distilabel", "microsoft-azure", "kernels", "google-cloud", ] EXPLORE_HF_DOCS_TOOL_SPEC = { "name": "explore_hf_docs", "description": ( "Browse HF documentation structure — discover all available documentation with 200-char previews.\n\n" "Use this to find relevant documentation and/or examples with detailed parameter docs and API reference. " "To be used together with github_find_examples and github_read_file to find working examples and documentation.\n\n" "Pattern: explore_hf_docs (find relevant pages) → fetch_hf_docs (get full content).\n\n" "For training tasks: fetch the trainer config docs (SFTConfig, DPOConfig, GRPOConfig) to verify parameter names. " "Returns top 20 results by default; set max_results (max 50) to adjust." ), "parameters": { "type": "object", "properties": { "endpoint": { "type": "string", "enum": DOC_ENDPOINTS, "description": ( "The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n" "• courses — All Hugging Face courses (LLM, robotics, MCP, smol (llm training), agents, deep RL, computer vision, games, diffusion, 3D, audio) and the cookbook recipes. Probably the best place for examples.\n" "• hub — Find answers to questions about models/datasets/spaces, auth, versioning, metadata.\n" "• transformers — Core model library: architectures, configs, tokenizers, training & inference APIs.\n" "• diffusers — Diffusion pipelines, schedulers, fine-tuning, training, and deployment patterns.\n" "• datasets — Dataset loading, streaming, processing, Arrow format, Hub integration.\n" "• gradio — UI components and demos for ML models. 
Uses Gradio's native API: without query returns full docs (llms.txt), with query uses embedding search for precise results.\n" "• trackio — Experiment tracking, metrics logging, and run comparison.\n" "• smolagents — Lightweight agent abstractions and tool-using patterns.\n" "• huggingface_hub — Python client for Hub operations (auth, upload/download, repo management).\n" "• huggingface.js — JS/TS client for Hub APIs in browser and Node.\n" "• transformers.js — Run Transformer models in browser/Node via WebGPU/WASM.\n" "• inference-providers — Unified interface for third-party inference backends.\n" "• inference-endpoints — Managed, scalable model deployments on HF infrastructure.\n" "• peft — Parameter-efficient fine-tuning methods (LoRA, adapters, etc.).\n" "• accelerate — Hardware-agnostic, distributed and mixed-precision training orchestration.\n" "• optimum — Hardware-aware optimization and model export tooling, including Habana, Neuron, Intel, ExecuTorch, and TPU variants.\n" "• tokenizers — Fast tokenizer internals, training, and low-level APIs.\n" "• evaluate — Metrics, evaluation workflows, and training-loop integration.\n" "• tasks — Canonical task definitions and model categorization.\n" "• dataset-viewer — Dataset preview, streaming views, and viewer internals.\n" "• trl — RLHF, DPO, PPO, and SFT utilities for LLMs.\n" "• simulate — Experimental simulation tools and workflows.\n" "• sagemaker — Deploying Hugging Face models on AWS SageMaker.\n" "• timm — Image model zoo and utilities via HF integrations.\n" "• safetensors — Safe, fast tensor serialization format.\n" "• tgi — High-throughput text generation server for LLMs.\n" "• setfit — Few-shot text classification via sentence embeddings.\n" "• lerobot — Robotics datasets, policies, and learning workflows.\n" "• autotrain — No/low-code model training on Hugging Face.\n" "• tei — Optimized inference server for embedding workloads.\n" "• bitsandbytes — Quantization and memory-efficient optimizers.\n" "• sentence_transformers — Embedding models, training recipes, similarity/search workflows.\n" "• chat-ui — Reference chat interfaces for LLM deployment.\n" "• leaderboards — Evaluation leaderboards and submission mechanics.\n" "• lighteval — Lightweight, reproducible LLM evaluation framework.\n" "• argilla — Data annotation, feedback, and human-in-the-loop workflows.\n" "• distilabel — Synthetic data generation and distillation pipelines.\n" "• microsoft-azure — Azure deployment and integration guides.\n" "• kernels — Lightweight execution environments and notebook-style workflows.\n" "• google-cloud — GCP deployment and serving workflows.\n" ), }, "query": { "type": "string", "description": ( "Optional keyword query to rank and filter documentation pages. " "For Gradio, use concise queries like 'how to use the image component' or 'audio component demo'." ), }, "max_results": { "type": "integer", "description": "Max results (default 20, max 50). Ignored for Gradio.", "minimum": 1, "maximum": 50, }, }, "required": ["endpoint"], }, } HF_DOCS_FETCH_TOOL_SPEC = { "name": "fetch_hf_docs", "description": ( "Fetch full markdown content of an HF documentation page. Use after explore_hf_docs.\n\n" "Critical for finding documentation e.g. current trainer configuration parameters (SFTConfig, DPOConfig, etc.) " "Use for researching solutions and before writing training scripts. Your internal knowledge is outdated.\n\n" "Provide the full URL from explore_hf_docs results. The .md extension is added automatically." 
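        # Illustrative sketch: hf_docs_fetch_handler above reduces to
        #   async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        #       resp = await client.get(
        #           "https://huggingface.co/docs/trl/dpo_trainer.md",
        #           headers={"Authorization": f"Bearer {hf_token}"},
        #       )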
), "parameters": { "type": "object", "properties": { "url": { "type": "string", "description": ( "The full URL to the documentation page. " "Example: 'https://huggingface.co/docs/trl/dpo_trainer' " "The .md extension will be added automatically if not present." ), }, }, "required": ["url"], }, } ================================================ FILE: agent/tools/edit_utils.py ================================================ """ Shared utilities for file editing tools — fuzzy matching, syntax validation, and richer edit operations. Used by both local_tools.py and the embedded sandbox server. """ from __future__ import annotations # ── Unicode normalization map ──────────────────────────────────────────── UNICODE_MAP = { "\u2013": "-", # en-dash "\u2014": "-", # em-dash "\u2212": "-", # minus sign "\u2018": "'", # left single quote "\u2019": "'", # right single quote "\u201c": '"', # left double quote "\u201d": '"', # right double quote "\u00a0": " ", # non-breaking space "\u2003": " ", # em space "\u2002": " ", # en space "\u200b": "", # zero-width space "\ufeff": "", # BOM } def _normalize_unicode(s: str) -> str: return "".join(UNICODE_MAP.get(c, c) for c in s) # ── 4-pass fuzzy matching ──────────────────────────────────────────────── def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]: """Find *pattern* in *content* with increasingly relaxed matching. Returns (start_index_in_original_content, match_note) or (None, None). The index always refers to the *original* content string so callers can use ``content[idx : idx + len(matched_text)]`` for replacement. Strategy (mirrors Codex): 1. Exact match 2. Right-trim each line (trailing whitespace) 3. Both-sides trim (all surrounding whitespace per line) 4. Unicode normalization on top of both-sides trim """ # Pass 1 — exact if pattern in content: return content.index(pattern), None # Helper: build a line-stripped version *and* a mapping from stripped # positions back to original positions. We need this so callers can # apply the replacement on the original content, not the stripped copy. def _build_stripped(text: str, strip_fn): """Return (stripped_text, line_start_map). line_start_map[i] = original byte offset of the start of line i. 
""" orig_lines = text.split("\n") stripped_lines = [strip_fn(l) for l in orig_lines] return "\n".join(stripped_lines), orig_lines, stripped_lines # Pass 2 — right-trim c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip) p_rt = "\n".join(l.rstrip() for l in pattern.split("\n")) idx = c_rt.find(p_rt) if idx != -1: orig_idx = _map_back(idx, c_orig_lines, c_rt_lines) return orig_idx, "(matched after trimming trailing whitespace)" # Pass 3 — both-sides trim c_st, _, c_st_lines = _build_stripped(content, str.strip) p_st = "\n".join(l.strip() for l in pattern.split("\n")) idx = c_st.find(p_st) if idx != -1: orig_idx = _map_back(idx, c_orig_lines, c_st_lines) return orig_idx, "(matched after trimming whitespace)" # Pass 4 — unicode normalization + both-sides trim c_norm = _normalize_unicode(c_st) p_norm = _normalize_unicode(p_st) idx = c_norm.find(p_norm) if idx != -1: orig_idx = _map_back(idx, c_orig_lines, c_st_lines) return orig_idx, "(matched after unicode normalization)" return None, None def _map_back( stripped_idx: int, orig_lines: list[str], stripped_lines: list[str], ) -> int: """Map a character index in the stripped/joined text back to the original text.""" # Walk through stripped lines to find which line the index falls on pos = 0 for i, sl in enumerate(stripped_lines): line_end = pos + len(sl) if stripped_idx <= line_end: col_in_stripped = stripped_idx - pos # Find where this stripped line's content starts in the original line ol = orig_lines[i] # The stripped line is a subset of the original line; find its offset lstripped = len(ol) - len(ol.lstrip()) orig_col = lstripped + col_in_stripped # Compute absolute position in original text orig_pos = sum(len(orig_lines[j]) + 1 for j in range(i)) + orig_col return orig_pos pos = line_end + 1 # +1 for the \n # Fallback: return 0 (shouldn't happen if idx is valid) return 0 def fuzzy_find_original_match(content: str, pattern: str) -> tuple[str | None, str | None]: """Find the *original* text in content that matches pattern fuzzily. Returns (original_matched_text, match_note) or (None, None). This extracts the exact substring from the original content that corresponds to the fuzzy match, preserving its original whitespace/unicode. """ if pattern in content: return pattern, None idx, note = fuzzy_find(content, pattern) if idx is None: return None, None # We need to find the original text span that corresponds to the match. # The match covers len(pattern) worth of *logical* content. # Count how many original lines the pattern spans. pattern_lines = pattern.split("\n") n_lines = len(pattern_lines) # Find which original line the match starts on orig_lines = content.split("\n") char_pos = 0 start_line = 0 for i, ol in enumerate(orig_lines): if char_pos + len(ol) >= idx: start_line = i break char_pos += len(ol) + 1 end_line = min(start_line + n_lines, len(orig_lines)) # Extract the original lines that were matched matched_lines = orig_lines[start_line:end_line] original_text = "\n".join(matched_lines) return original_text, note # ── Richer edit operations ─────────────────────────────────────────────── def apply_edit( content: str, old_str: str, new_str: str, mode: str = "replace", replace_all: bool = False, ) -> tuple[str, int, str | None]: """Apply an edit operation to content. 
Modes: - replace: replace first occurrence (or all if replace_all=True) - replace_all: replace all occurrences (alias) - append_after: insert new_str after old_str - prepend_before: insert new_str before old_str Returns (new_content, num_replacements, fuzzy_note). Raises ValueError if old_str not found. """ if mode == "replace_all": replace_all = True mode = "replace" # Try exact match first, then fuzzy fuzzy_note = None if old_str not in content: original_match, fuzzy_note = fuzzy_find_original_match(content, old_str) if original_match is None: raise ValueError( "old_str was not found in the file. Make sure old_str matches " "the file contents exactly, including whitespace and indentation. " "Use the read tool to verify the current file contents before retrying." ) old_str = original_match count = content.count(old_str) if mode == "replace": if count > 1 and not replace_all: raise ValueError( f"Found {count} matches of old_str in the file, but replace_all is " f"false. To replace all occurrences, set replace_all to true. To " f"replace only one, provide a larger old_str with more surrounding " f"context to uniquely identify the instance." ) if replace_all: new_content = content.replace(old_str, new_str) return new_content, count, fuzzy_note else: new_content = content.replace(old_str, new_str, 1) return new_content, 1, fuzzy_note elif mode == "append_after": if replace_all: new_content = content.replace(old_str, old_str + new_str) return new_content, count, fuzzy_note else: idx = content.index(old_str) + len(old_str) new_content = content[:idx] + new_str + content[idx:] return new_content, 1, fuzzy_note elif mode == "prepend_before": if replace_all: new_content = content.replace(old_str, new_str + old_str) return new_content, count, fuzzy_note else: idx = content.index(old_str) new_content = content[:idx] + new_str + content[idx:] return new_content, 1, fuzzy_note else: raise ValueError(f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before.") # ── Syntax validation (Python) ─────────────────────────────────────────── def validate_python(content: str, path: str = "") -> list[str]: """Lightweight post-write validation for Python files. Checks syntax and training script conventions. This runs on the host (not in the sandbox), so it only does static checks — no import resolution or signature inspection since packages are installed in the sandbox, not here. The sandbox server has its own richer version that does real signature inspection against installed packages. Returns a list of warning strings (empty = all good). Never raises — validation failures are advisory only. """ import ast warnings = [] # 1. Syntax check via ast.parse try: ast.parse(content) except SyntaxError as e: warnings.append(f"Python syntax error at line {e.lineno}: {e.msg}") return warnings # 2. Training script heuristics if any(kw in content for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")): if "push_to_hub" not in content: warnings.append( "Training script warning: no 'push_to_hub' found — model may be lost when job ends" ) if "hub_model_id" not in content: warnings.append( "Training script warning: no 'hub_model_id' found" ) return warnings ================================================ FILE: agent/tools/github_find_examples.py ================================================ """ GitHub Find Examples Tool - Discover examples, tutorials, and guides for any library Lists all files in a repository and performs deterministic keyword search. 
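Example (illustrative; requires GITHUB_TOKEN in the environment):

    result = find_examples(keyword="grpo", repo="trl")
    print(result["formatted"])  # ranked example files with GitHub URLs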
""" import os from typing import Any, Dict, List import requests from thefuzz import fuzz from agent.tools.types import ToolResult # In order of priority (lower index = higher priority for sorting) EXAMPLE_PATTERNS = [ "scripts", # General example patterns (catch-all, lower priority) "examples", "example", # Notebook patterns "notebooks", "notebook", # Tutorial/learning patterns "tutorials", "tutorial", "quickstart", "walkthroughs", "walkthrough", # Cookbook/recipe patterns "cookbook", "cookbooks", "recipes", "recipe", # Demo/sample patterns "demos", "demo", "samples", "sample", # Other patterns "guides", "guide", "getting-started", "getting_started", "playground", "howto", "how-to", "use-cases", "usecases", "use_cases", "sandbox", "showcase", ] def _get_repo_tree(org: str, repo: str, token: str) -> tuple[List[Dict[str, Any]], str]: """Get all files in a repository recursively. Returns (files, error_message)""" headers = { "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", "Authorization": f"Bearer {token}", } full_repo = f"{org}/{repo}" # Get default branch try: response = requests.get( f"https://api.github.com/repos/{full_repo}", headers=headers, timeout=10 ) if response.status_code == 404: return [], "not_found" if response.status_code != 200: return [], f"API error: {response.status_code}" repo_data = response.json() default_branch = repo_data.get("default_branch", "main") except Exception as e: return [], f"Error fetching repo: {str(e)}" # Get repository tree recursively try: response = requests.get( f"https://api.github.com/repos/{full_repo}/git/trees/{default_branch}", headers=headers, params={"recursive": "1"}, timeout=30, ) if response.status_code != 200: return [], f"Error fetching tree: {response.status_code}" data = response.json() tree = data.get("tree", []) # Filter to only include files (not directories) files = [ { "path": item["path"], "ref": item["sha"], "size": item.get("size", 0), "url": f"https://github.com/{full_repo}/blob/{default_branch}/{item['path']}", } for item in tree if item["type"] == "blob" ] return files, "" except Exception as e: return [], f"Error processing tree: {str(e)}" def _search_similar_repos(org: str, repo: str, token: str) -> List[Dict[str, Any]]: """Search for similar repository names in the organization""" headers = { "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", "Authorization": f"Bearer {token}", } # Search for repos in the org with similar name query = f"org:{org} {repo}" try: response = requests.get( "https://api.github.com/search/repositories", headers=headers, params={"q": query, "sort": "stars", "order": "desc", "per_page": 10}, timeout=30, ) if response.status_code != 200: return [] data = response.json() items = data.get("items", []) return [ { "name": item.get("name"), "full_name": item.get("full_name"), "description": item.get("description"), "stars": item.get("stargazers_count", 0), "url": item.get("html_url"), } for item in items ] except Exception: return [] def _score_against_example_patterns(file_path: str) -> int: """Score file against example patterns using token_set_ratio""" scores = [] for pattern in EXAMPLE_PATTERNS: score = fuzz.token_set_ratio(pattern.lower(), file_path.lower()) scores.append(score) return max(scores) if scores else 0 def _score_against_keyword(file_path: str, keyword: str) -> int: """Calculate fuzzy match score for a file path against a keyword""" # Use partial_ratio for substring matching (good for paths) # Also check token_set_ratio for 
word-level matching partial_score = fuzz.partial_ratio(keyword.lower(), file_path.lower()) token_score = fuzz.token_set_ratio(keyword.lower(), file_path.lower()) # Return the higher of the two return max(partial_score, token_score) def _get_pattern_priority(file_path: str) -> tuple[int, int, int]: """ Get priority of a file path based on which example pattern directory it's in. Returns: (in_examples_dir, pattern_priority, path_depth) - in_examples_dir: 0 if in examples/ directory, 1 otherwise (lower is better) - pattern_priority: Index in EXAMPLE_PATTERNS (lower is better), or 999 if no match - path_depth: Number of path segments (lower is better) Note: Prioritizes files in "examples/" directory first, then by most specific pattern match. E.g., "examples/scripts/train.py" is better than "scripts/util.py" """ path_lower = file_path.lower() path_parts = path_lower.split("/") # Check if file is in examples/ directory (highest priority) in_examples_dir = 0 if (path_parts[0] in ["examples", "example"]) else 1 # Find ALL matching patterns and use the best (lowest index) one # But prefer deeper matches (more specific) over shallow ones best_priority = 999 best_depth_at_match = -1 for i, pattern in enumerate(EXAMPLE_PATTERNS): # Check if pattern appears as a directory component in the path if pattern in path_parts: # Find the depth where this pattern appears (rightmost occurrence) depth = len(path_parts) - 1 - path_parts[::-1].index(pattern) # Prefer deeper matches, or better priority if at same depth if depth > best_depth_at_match or ( depth == best_depth_at_match and i < best_priority ): best_priority = i best_depth_at_match = depth return (in_examples_dir, best_priority, len(path_parts)) def _handle_repo_tree_errors( all_files: List[Dict[str, Any]], error: str, org: str, repo: str, token: str, ) -> ToolResult | None: """Handle errors from repo tree fetch. Returns ToolResult if error, None if OK.""" if error == "not_found": similar_repos = _search_similar_repos(org, repo, token) if not similar_repos: return { "formatted": f"Repository '{org}/{repo}' not found and no similar repositories found.", "totalResults": 0, "resultsShared": 0, "isError": True, } # Format similar repos lines = [f"**Repository '{org}/{repo}' not found. Similar repositories:**\n"] for i, r in enumerate(similar_repos, 1): lines.append(f"{i}. **{r['full_name']}** (⭐ {r['stars']:,} stars)") if r["description"]: desc = ( r["description"][:100] + "..." if len(r["description"]) > 100 else r["description"] ) lines.append(f" {desc}") lines.append(f" {r['url']}\n") return { "formatted": "\n".join(lines), "totalResults": len(similar_repos), "resultsShared": len(similar_repos), "isError": True, } if error: return { "formatted": f"Error accessing repository '{org}/{repo}': {error}", "totalResults": 0, "resultsShared": 0, "isError": True, } if not all_files: return { "formatted": f"No files found in repository '{org}/{repo}'", "totalResults": 0, "resultsShared": 0, } return None def find_examples( keyword: str = "", repo: str = "", org: str = "huggingface", max_results: int = 10, min_score: int = 80, ) -> ToolResult: """ Find example files in a repository using fuzzy matching. 
    Args:
        keyword: Keyword to fuzzy match against file paths (e.g., "grpo")
        repo: Repository name (e.g., "trl")
        org: GitHub organization (default: "huggingface")
        max_results: Maximum number of results (default 10; the agent tool
            handler passes 50 when the tool call omits it)
        min_score: Minimum fuzzy match score, 0-100 (default 80; the agent
            tool handler passes 60 when the tool call omits it)

    Returns:
        ToolResult with matching files, or similar repos if repo not found
    """
    token = os.environ.get("GITHUB_TOKEN")
    if not token:
        return {
            "formatted": "Error: GITHUB_TOKEN environment variable is required",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }

    if not repo:
        return {
            "formatted": "Error: repo parameter is required",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }

    # Get all files in the repository
    all_files, error = _get_repo_tree(org, repo, token)

    # Handle errors (not found, API errors, empty repo)
    if error_result := _handle_repo_tree_errors(all_files, error, org, repo, token):
        return error_result

    # Step 1: Filter files by example patterns (score >= 60)
    example_threshold = 60
    example_files = []
    for file in all_files:
        example_score = _score_against_example_patterns(file["path"])
        if example_score >= example_threshold:
            example_files.append({**file, "example_score": example_score})

    if not example_files:
        return {
            "formatted": f"No example files found in {org}/{repo} (no files match example patterns with score >= {example_threshold}).",
            "totalResults": 0,
            "resultsShared": 0,
        }

    # Step 2: If keyword provided, score and filter by keyword
    if keyword:
        scored_files = []
        for file in example_files:
            keyword_score = _score_against_keyword(file["path"], keyword)
            if keyword_score >= min_score:
                scored_files.append({**file, "score": keyword_score})

        if not scored_files:
            return {
                "formatted": f"No files found in {org}/{repo} matching keyword '{keyword}' (min score: {min_score}) among {len(example_files)} example files.",
                "totalResults": 0,
                "resultsShared": 0,
            }

        # Sort by keyword score (descending) for best matches first
        scored_files.sort(key=lambda x: x["score"], reverse=True)
    else:
        # No keyword: prioritize by pattern directory, then path depth
        scored_files = []
        for file in example_files:
            in_examples_dir, pattern_priority, path_depth = _get_pattern_priority(
                file["path"]
            )
            scored_files.append(
                {
                    **file,
                    "score": file["example_score"],
                    "in_examples_dir": in_examples_dir,
                    "pattern_priority": pattern_priority,
                    "path_depth": path_depth,
                }
            )

        if not scored_files:
            return {
                "formatted": f"No example files found in {org}/{repo}.",
                "totalResults": 0,
                "resultsShared": 0,
            }

        # Sort by: 1) files in examples/ dir first, 2) pattern priority (scripts > datasets > etc), 3) path depth, 4) path name
        scored_files.sort(
            key=lambda x: (
                x["in_examples_dir"],
                x["pattern_priority"],
                x["path_depth"],
                x["path"],
            )
        )

    # Limit results
    results = scored_files[:max_results]

    # Format output
    keyword_desc = f" matching '{keyword}'" if keyword else ""
    lines = [f"**Found {len(results)} example files in {org}/{repo}{keyword_desc}:**"]
    if len(scored_files) > max_results:
        lines[0] += f" (showing {max_results} of {len(scored_files)})"
    lines.append("")

    for i, file in enumerate(results, 1):
        lines.append(f"{i}. 
**{file['path']}**") lines.append(f" Size: {file['size']:,} bytes | Ref: {file['ref'][:7]}") lines.append(f" URL: {file['url']}") # Copyable parameters for read_file tool read_params = f"{{'repo': '{org}/{repo}', 'path': '{file['path']}'}}" lines.append(f" To read, use: {read_params}") lines.append("") return { "formatted": "\n".join(lines), "totalResults": len(results), "resultsShared": len(results), } # Tool specification GITHUB_FIND_EXAMPLES_TOOL_SPEC = { "name": "github_find_examples", "description": ( "Find working example scripts in GitHub repositories (from a list of predetermined directories e.g. examples/, scripts/, tutorials/, etc.). " "Uses fuzzy keyword matching.\n\n" "MANDATORY before writing any ML training, fine-tuning, or inference code. " "Your internal knowledge of library APIs is outdated — working examples show current API patterns.\n\n" "Sequence: github_find_examples → github_read_file (study the example) → implement based on what you found.\n\n" "Skip this only for: simple data queries, status checks, non-code tasks.\n\n" "Examples:\n" " {keyword: 'sft', repo: 'trl'} → finds examples/scripts/sft.py\n" " {keyword: 'grpo', repo: 'trl'} → finds GRPO training examples\n" " {repo: 'trl', max_results: 20} → lists all available training method examples" ), "parameters": { "type": "object", "properties": { "keyword": { "type": "string", "description": "Keyword to fuzzy match against file paths (e.g., 'grpo', 'sft').", }, "repo": { "type": "string", "description": "Repository name (e.g., 'trl', 'transformers'). Required.", }, "org": { "type": "string", "description": "GitHub organization or username. Default: 'huggingface'.", }, "max_results": { "type": "integer", "description": "Maximum number of results to return. Default: 50.", }, "min_score": { "type": "integer", "description": "Minimum fuzzy match score (0-100). Default: 60.", }, }, "required": ["repo"], }, } async def github_find_examples_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router""" try: result = find_examples( keyword=arguments.get("keyword", ""), repo=arguments["repo"], org=arguments.get("org", "huggingface"), max_results=arguments.get("max_results", 50), min_score=arguments.get("min_score", 60), ) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error finding examples: {str(e)}", False ================================================ FILE: agent/tools/github_list_repos.py ================================================ """ GitHub List Repositories Tool - List and sort repositories for any user or organization Efficiently discover repositories with flexible sorting options. """ import os from typing import Any, Dict, Literal, Optional import requests from agent.tools.types import ToolResult def list_repos( owner: str, owner_type: Literal["user", "org"] = "org", sort: Literal["stars", "forks", "updated", "created"] = "stars", order: Literal["asc", "desc"] = "desc", limit: Optional[int] = 30, ) -> ToolResult: """ List repositories for a user or organization using GitHub REST API. 
Args: owner: GitHub username or organization name owner_type: Whether the owner is a "user" or "org" (default: "org") sort: Sort field - "stars", "forks", "updated", or "created" order: Sort order - "asc" or "desc" (default: "desc") limit: Maximum number of repositories to return Returns: ToolResult with repository information """ token = os.environ.get("GITHUB_TOKEN") if not token: return { "formatted": "Error: GITHUB_TOKEN environment variable is required", "totalResults": 0, "resultsShared": 0, "isError": True, } if owner_type == "org": url = f"https://api.github.com/orgs/{owner}/repos" else: url = f"https://api.github.com/users/{owner}/repos" headers = { "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", "Authorization": f"Bearer {token}", } all_repos = [] page = 1 per_page = 100 # Maximum allowed by GitHub # Map our sort values to GitHub API sort values # Note: GitHub list repos API doesn't support sorting by stars/forks # We'll fetch all repos and sort in memory for those cases api_sort_map = { "created": "created", "updated": "updated", "stars": None, # Not supported by list API "forks": None, # Not supported by list API } api_sort = api_sort_map.get(sort) need_manual_sort = api_sort is None try: while True: params = { "page": page, "per_page": per_page, } # Only add sort/direction if API supports it if api_sort: params["sort"] = api_sort params["direction"] = order response = requests.get( url, headers=headers, params=params, timeout=30, ) if response.status_code == 403: error_data = response.json() return { "formatted": f"GitHub API rate limit or permission error: {error_data.get('message', 'Unknown error')}", "totalResults": 0, "resultsShared": 0, "isError": True, } if response.status_code != 200: error_msg = f"GitHub API error (status {response.status_code})" try: error_data = response.json() if "message" in error_data: error_msg += f": {error_data['message']}" except Exception: pass return { "formatted": error_msg, "totalResults": 0, "resultsShared": 0, "isError": True, } items = response.json() if not items: break for item in items: all_repos.append( { "name": item.get("name"), "full_name": item.get("full_name"), "description": item.get("description"), "html_url": item.get("html_url"), "language": item.get("language"), "stars": item.get("stargazers_count", 0), "forks": item.get("forks_count", 0), "open_issues": item.get("open_issues_count", 0), "topics": item.get("topics", []), "updated_at": item.get("updated_at"), "created_at": item.get("created_at"), } ) # Check if we got fewer results than requested (last page) if len(items) < per_page: break # Stop if we have enough repos if limit and len(all_repos) >= limit: break page += 1 except requests.exceptions.RequestException as e: return { "formatted": f"Failed to connect to GitHub API: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } # Manual sorting if needed (for stars/forks) if need_manual_sort and all_repos: reverse = order == "desc" all_repos.sort(key=lambda x: x[sort], reverse=reverse) # Apply limit after sorting if limit: all_repos = all_repos[:limit] if not all_repos: return { "formatted": f"No repositories found for {owner_type} '{owner}'", "totalResults": 0, "resultsShared": 0, } # Format output lines = [f"**Found {len(all_repos)} repositories for {owner}:**\n"] for i, repo in enumerate(all_repos, 1): lines.append(f"{i}. 
**{repo['full_name']}**") lines.append( f" ⭐ {repo['stars']:,} stars | 🍴 {repo['forks']:,} forks | Language: {repo['language'] or 'N/A'}" ) if repo["description"]: desc = ( repo["description"][:100] + "..." if len(repo["description"]) > 100 else repo["description"] ) lines.append(f" {desc}") lines.append(f" URL: {repo['html_url']}") if repo["topics"]: lines.append(f" Topics: {', '.join(repo['topics'][:5])}") # Copyable parameters for other tools lines.append(f" Use in tools: {{'repo': '{repo['full_name']}'}}") lines.append("") return { "formatted": "\n".join(lines), "totalResults": len(all_repos), "resultsShared": len(all_repos), } # Tool specification GITHUB_LIST_REPOS_TOOL_SPEC = { "name": "github_list_repos", "description": ( "List and discover repositories for GitHub organizations or users with flexible sorting. " "**Use when:** (1) Exploring what libraries exist for a task, (2) Finding the right library to use, " "(3) Discovering popular or active projects, (4) Checking recently updated repos for latest features, " "(5) Finding alternative libraries in an organization. " "**Pattern:** github_list_repos (discover libraries) → github_find_examples (find usage examples) → implement. " "Returns: Comprehensive repository information (stars, forks, language, topics, URLs), sorted by preference. " "**Then:** Use github_find_examples on selected repo to discover example code. " "Sorts by: stars (popularity), forks (community), updated (activity), created (age).\n\n" "## When to use this tool\n\n" "- When you need to find libraries to use in your implementation\n" "- When exploring what repositories exist for a task or domain\n" "- When debugging an error and looking up if others have similar issues in repos\n" "- When finding the most popular or actively maintained projects for a user/org\n" "## Examples\n\n" "\n" "// ML Workflow Step: Discover HF libraries for RLHF/alignment\n" "// Use case: Find the right library for training with human feedback\n" "{\n" " owner: 'huggingface',\n" " owner_type: 'org',\n" " sort: 'stars',\n" " limit: 10\n" "}\n" "// Returns: transformers, trl, peft, accelerate, diffusers...\n" "\n\n" "\n" "// ML Workflow Step: Check for recently updated HF repos\n" "// Use case: Find actively maintained libraries with latest features\n" "{\n" " owner: 'huggingface',\n" " owner_type: 'org',\n" " sort: 'updated',\n" " order: 'desc',\n" " limit: 15\n" "}\n" "// Helps identify which repos have recent improvements/fixes\n" "" ), "parameters": { "type": "object", "properties": { "owner": { "type": "string", "description": "GitHub username or organization name. Required.", }, "owner_type": { "type": "string", "enum": ["user", "org"], "description": "Whether the owner is a 'user' or 'org'. Default: 'org'.", }, "sort": { "type": "string", "enum": ["stars", "forks", "updated", "created"], "description": "Sort field. Options: 'stars', 'forks', 'updated', 'created'. Default: 'stars'.", }, "order": { "type": "string", "enum": ["asc", "desc"], "description": "Sort order. Options: 'asc', 'desc'. Default: 'desc'.", }, "limit": { "type": "integer", "description": "Maximum number of repositories to return. No limit if not specified. 
Default: 30.", }, }, "required": ["owner"], }, } async def github_list_repos_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router""" try: result = list_repos( owner=arguments["owner"], owner_type=arguments.get("owner_type", "org"), sort=arguments.get("sort", "stars"), order=arguments.get("order", "desc"), limit=arguments.get("limit"), ) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error listing repositories: {str(e)}", False ================================================ FILE: agent/tools/github_read_file.py ================================================ """ GitHub Read File Tool - Read file contents from any GitHub repository with line range support Fetch exact file contents with metadata, supporting line ranges for efficient reading. """ import base64 import json import os from typing import Any, Dict, Optional import nbformat import requests from nbconvert import MarkdownExporter from nbconvert.preprocessors import ClearOutputPreprocessor, TagRemovePreprocessor from agent.tools.types import ToolResult def _convert_ipynb_to_markdown(content: str) -> str: """ Convert Jupyter notebook JSON to LLM-friendly Markdown. Args: content: Raw notebook JSON string Returns: Converted Markdown string """ try: # Parse notebook JSON nb_dict = json.loads(content) # Normalize cell sources (can be string or list of strings) if "cells" in nb_dict: for cell in nb_dict["cells"]: if "source" in cell and isinstance(cell["source"], list): cell["source"] = "".join(cell["source"]) # Read notebook with explicit version nb = nbformat.reads(json.dumps(nb_dict), as_version=4) # Strip outputs for LLM readability (outputs can be noisy/large) clear = ClearOutputPreprocessor() nb, _ = clear.preprocess(nb, {}) # Optionally remove cells tagged with "hide" or similar remove = TagRemovePreprocessor( remove_cell_tags={"hide", "hidden", "remove"}, remove_input_tags=set(), remove_all_outputs_tags=set(), ) nb, _ = remove.preprocess(nb, {}) # Convert to markdown exporter = MarkdownExporter() markdown, _ = exporter.from_notebook_node(nb) return markdown except json.JSONDecodeError: return content except Exception: return content def read_file( repo: str, path: str, ref: str = "HEAD", line_start: Optional[int] = None, line_end: Optional[int] = None, ) -> ToolResult: """ Read file contents from a GitHub repository with line range support. 
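    Example (illustrative; requires GITHUB_TOKEN in the environment):

        result = read_file(repo="huggingface/trl", path="examples/scripts/sft.py",
                           line_start=1, line_end=40)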
Args: repo: Repository in format "owner/repo" (e.g., "github/github-mcp-server") path: Path to file in repository (e.g., "pkg/github/search.go") ref: Git reference - branch name, tag, or commit SHA (default: "HEAD") line_start: Starting line number (1-indexed, inclusive) line_end: Ending line number (1-indexed, inclusive) Returns: ToolResult with file contents and metadata """ token = os.environ.get("GITHUB_TOKEN") if not token: return { "formatted": "Error: GITHUB_TOKEN environment variable is required", "totalResults": 0, "resultsShared": 0, "isError": True, } # Parse repo if "/" not in repo: return { "formatted": "Error: repo must be in format 'owner/repo'", "totalResults": 0, "resultsShared": 0, "isError": True, } owner, repo_name = repo.split("/", 1) headers = { "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28", "Authorization": f"Bearer {token}", } # Fetch file contents url = f"https://api.github.com/repos/{owner}/{repo_name}/contents/{path}" params = {} if ref and ref != "HEAD": params["ref"] = ref try: response = requests.get(url, headers=headers, params=params, timeout=30) if response.status_code == 404: return { "formatted": f"File not found: {path} in {repo} (ref: {ref})", "totalResults": 0, "resultsShared": 0, "isError": True, } if response.status_code != 200: error_msg = f"GitHub API error (status {response.status_code})" try: error_data = response.json() if "message" in error_data: error_msg += f": {error_data['message']}" except Exception: pass return { "formatted": error_msg, "totalResults": 0, "resultsShared": 0, "isError": True, } data = response.json() # Check if it's a file if data.get("type") != "file": return { "formatted": f"Path {path} is not a file (type: {data.get('type')})", "totalResults": 0, "resultsShared": 0, "isError": True, } # Decode content content_b64 = data.get("content", "") if content_b64: content_b64 = content_b64.replace("\n", "").replace(" ", "") content = base64.b64decode(content_b64).decode("utf-8", errors="replace") else: # For large files, fetch raw content raw_headers = { "Accept": "application/vnd.github.raw", "X-GitHub-Api-Version": "2022-11-28", "Authorization": f"Bearer {token}", } raw_response = requests.get( url, headers=raw_headers, params=params, timeout=30 ) if raw_response.status_code != 200: return { "formatted": "Failed to fetch file content", "totalResults": 0, "resultsShared": 0, "isError": True, } content = raw_response.text if path.lower().endswith(".ipynb"): content = _convert_ipynb_to_markdown(content) # Process line ranges lines = content.split("\n") total_lines = len(lines) truncated = False if line_start is None and line_end is None: # No range specified if total_lines > 300: line_start = 1 line_end = 300 truncated = True else: line_start = 1 line_end = total_lines else: # Range specified if line_start is None: line_start = 1 if line_end is None: line_end = total_lines # Validate range line_start = max(1, line_start) line_end = min(total_lines, line_end) if line_start > line_end: return { "formatted": f"Invalid range: line_start ({line_start}) > line_end ({line_end})", "totalResults": 0, "resultsShared": 0, "isError": True, } # Extract lines selected_lines = lines[line_start - 1 : line_end] selected_content = "\n".join(selected_lines) # Format output lines_output = [f"**Reading file from repo: {repo}, path: {path}**"] if ref and ref != "HEAD": lines_output.append(f"Ref: {ref}") lines_output.append("\n**File content:") lines_output.append("```") lines_output.append(selected_content) 
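        # Note: when no line range is given, files longer than 300 lines are
        # truncated to lines 1-300 (see above); the hint appended below tells
        # the caller how to page through the remainder.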
lines_output.append("```") if truncated: lines_output.append( f"Currently showing lines {line_start}-{line_end} out of {total_lines} total lines. Use line_start and line_end to view more lines." ) return { "formatted": "\n".join(lines_output), "totalResults": 1, "resultsShared": 1, } except requests.exceptions.RequestException as e: return { "formatted": f"Failed to connect to GitHub API: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } # Tool specification GITHUB_READ_FILE_TOOL_SPEC = { "name": "github_read_file", "description": ( "Read file contents from GitHub repositories. Returns first 300 lines by default. " "Auto-converts Jupyter notebooks to markdown.\n\n" "Use AFTER github_find_examples to study the working implementation. " "The purpose is to learn current API patterns — imports, trainer configs, dataset handling — " "so your implementation uses correct, up-to-date code.\n\n" "Use line_start/line_end for large files (>300 lines) to read specific sections.\n\n" "When NOT to use: when you don't know the file path (use github_find_examples first)." ), "parameters": { "type": "object", "properties": { "repo": { "type": "string", "description": "Repository in format 'owner/repo' (e.g., 'github/github-mcp-server'). Required.", }, "path": { "type": "string", "description": "Path to file in repository (e.g., 'src/index.js'). Required.", }, "ref": { "type": "string", "description": "Git reference - branch name, tag, or commit SHA. Default: 'HEAD'.", }, "line_start": { "type": "integer", "description": "Starting line number (1-indexed, inclusive). Optional.", }, "line_end": { "type": "integer", "description": "Ending line number (1-indexed, inclusive). Optional.", }, }, "required": ["repo", "path"], }, } async def github_read_file_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router""" try: result = read_file( repo=arguments["repo"], path=arguments["path"], ref=arguments.get("ref", "HEAD"), line_start=arguments.get("line_start"), line_end=arguments.get("line_end"), ) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error reading file: {str(e)}", False ================================================ FILE: agent/tools/hf_repo_files_tool.py ================================================ """ HF Repo Files Tool - File operations on Hugging Face repositories Operations: list, read, upload, delete """ import asyncio from typing import Any, Dict, Literal, Optional from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError from agent.tools.types import ToolResult OperationType = Literal["list", "read", "upload", "delete"] async def _async_call(func, *args, **kwargs): """Wrap synchronous HfApi calls for async context.""" return await asyncio.to_thread(func, *args, **kwargs) def _build_repo_url(repo_id: str, repo_type: str = "model") -> str: """Build the Hub URL for a repository.""" if repo_type == "model": return f"https://huggingface.co/{repo_id}" return f"https://huggingface.co/{repo_type}s/{repo_id}" def _format_size(size_bytes: int) -> str: """Format file size in human-readable form.""" for unit in ["B", "KB", "MB", "GB", "TB"]: if size_bytes < 1024: return f"{size_bytes:.1f}{unit}" size_bytes /= 1024 return f"{size_bytes:.1f}PB" class HfRepoFilesTool: """Tool for file operations on HF repos.""" def __init__(self, hf_token: Optional[str] = None): self.api = HfApi(token=hf_token) async def execute(self, args: Dict[str, Any]) 
-> ToolResult: """Execute the specified operation.""" operation = args.get("operation") if not operation: return self._help() try: handlers = { "list": self._list, "read": self._read, "upload": self._upload, "delete": self._delete, } handler = handlers.get(operation) if handler: return await handler(args) else: return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete") except RepositoryNotFoundError: return self._error(f"Repository not found: {args.get('repo_id')}") except EntryNotFoundError: return self._error(f"File not found: {args.get('path')}") except Exception as e: return self._error(f"Error: {str(e)}") def _help(self) -> ToolResult: """Show usage instructions.""" return { "formatted": """**hf_repo_files** - File operations on HF repos **Operations:** - `list` - List files: `{"operation": "list", "repo_id": "gpt2"}` - `read` - Read file: `{"operation": "read", "repo_id": "gpt2", "path": "config.json"}` - `upload` - Upload: `{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "..."}` - `delete` - Delete: `{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp"]}` **Common params:** repo_id (required), repo_type (model/dataset/space), revision (default: main)""", "totalResults": 1, "resultsShared": 1, } async def _list(self, args: Dict[str, Any]) -> ToolResult: """List files in a repository.""" repo_id = args.get("repo_id") if not repo_id: return self._error("repo_id is required") repo_type = args.get("repo_type", "model") revision = args.get("revision", "main") path = args.get("path", "") items = list(await _async_call( self.api.list_repo_tree, repo_id=repo_id, repo_type=repo_type, revision=revision, path_in_repo=path, recursive=True, )) if not items: return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0} lines = [] total_size = 0 for item in sorted(items, key=lambda x: x.path): if hasattr(item, "size") and item.size: total_size += item.size lines.append(f"{item.path} ({_format_size(item.size)})") else: lines.append(f"{item.path}/") url = _build_repo_url(repo_id, repo_type) response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines) return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)} async def _read(self, args: Dict[str, Any]) -> ToolResult: """Read file content from a repository.""" repo_id = args.get("repo_id") path = args.get("path") if not repo_id: return self._error("repo_id is required") if not path: return self._error("path is required") repo_type = args.get("repo_type", "model") revision = args.get("revision", "main") max_chars = args.get("max_chars", 50000) file_path = await _async_call( hf_hub_download, repo_id=repo_id, filename=path, repo_type=repo_type, revision=revision, token=self.api.token, ) try: with open(file_path, "r", encoding="utf-8") as f: content = f.read() truncated = len(content) > max_chars if truncated: content = content[:max_chars] url = f"{_build_repo_url(repo_id, repo_type)}/blob/{revision}/{path}" response = f"**{path}**{' (truncated)' if truncated else ''}\n{url}\n\n```\n{content}\n```" return {"formatted": response, "totalResults": 1, "resultsShared": 1} except UnicodeDecodeError: import os size = os.path.getsize(file_path) return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1} async def _upload(self, args: Dict[str, Any]) -> ToolResult: """Upload content to a repository.""" repo_id = args.get("repo_id") path = 
args.get("path") content = args.get("content") if not repo_id: return self._error("repo_id is required") if not path: return self._error("path is required") if content is None: return self._error("content is required") repo_type = args.get("repo_type", "model") revision = args.get("revision", "main") create_pr = args.get("create_pr", False) commit_message = args.get("commit_message", f"Upload {path}") file_bytes = content.encode("utf-8") if isinstance(content, str) else content result = await _async_call( self.api.upload_file, path_or_fileobj=file_bytes, path_in_repo=path, repo_id=repo_id, repo_type=repo_type, revision=revision, commit_message=commit_message, create_pr=create_pr, ) url = _build_repo_url(repo_id, repo_type) if create_pr and hasattr(result, "pr_url"): response = f"**Uploaded as PR**\n{result.pr_url}" else: response = f"**Uploaded:** {path}\n{url}/blob/{revision}/{path}" return {"formatted": response, "totalResults": 1, "resultsShared": 1} async def _delete(self, args: Dict[str, Any]) -> ToolResult: """Delete files from a repository.""" repo_id = args.get("repo_id") patterns = args.get("patterns") if not repo_id: return self._error("repo_id is required") if not patterns: return self._error("patterns is required (list of paths/wildcards)") if isinstance(patterns, str): patterns = [patterns] repo_type = args.get("repo_type", "model") revision = args.get("revision", "main") create_pr = args.get("create_pr", False) commit_message = args.get("commit_message", f"Delete {', '.join(patterns)}") await _async_call( self.api.delete_files, repo_id=repo_id, delete_patterns=patterns, repo_type=repo_type, revision=revision, commit_message=commit_message, create_pr=create_pr, ) response = f"**Deleted:** {', '.join(patterns)} from {repo_id}" return {"formatted": response, "totalResults": 1, "resultsShared": 1} def _error(self, message: str) -> ToolResult: """Return an error result.""" return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} # Tool specification HF_REPO_FILES_TOOL_SPEC = { "name": "hf_repo_files", "description": ( "Read and write files in HF repos (models/datasets/spaces).\n\n" "## Operations\n" "- **list**: List files with sizes and structure\n" "- **read**: Read file content (text files only)\n" "- **upload**: Upload content to repo (can create PR)\n" "- **delete**: Delete files/folders (supports wildcards like *.tmp)\n\n" "## Use when\n" "- Need to see what files exist in a repo\n" "- Want to read config.json, README.md, or other text files\n" "- Uploading training scripts, configs, or results to a repo\n" "- Cleaning up temporary files from a repo\n\n" "## Examples\n" '{"operation": "list", "repo_id": "meta-llama/Llama-2-7b"}\n' '{"operation": "read", "repo_id": "gpt2", "path": "config.json"}\n' '{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "# My Model"}\n' '{"operation": "upload", "repo_id": "org/model", "path": "fix.py", "content": "...", "create_pr": true}\n' '{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp", "logs/"]}\n\n' "## Notes\n" "- For binary files (safetensors, bin), use list to see them but can't read content\n" "- upload/delete require approval (can overwrite/destroy data)\n" "- Use create_pr=true to propose changes instead of direct commit\n" ), "parameters": { "type": "object", "properties": { "operation": { "type": "string", "enum": ["list", "read", "upload", "delete"], "description": "Operation: list, read, upload, delete", }, "repo_id": { "type": "string", "description": 
"Repository ID (e.g., 'username/repo-name')", }, "repo_type": { "type": "string", "enum": ["model", "dataset", "space"], "description": "Repository type (default: model)", }, "revision": { "type": "string", "description": "Branch/tag/commit (default: main)", }, "path": { "type": "string", "description": "File path for read/upload", }, "content": { "type": "string", "description": "File content for upload", }, "patterns": { "type": "array", "items": {"type": "string"}, "description": "Patterns to delete (e.g., ['*.tmp', 'logs/'])", }, "create_pr": { "type": "boolean", "description": "Create PR instead of direct commit", }, "commit_message": { "type": "string", "description": "Custom commit message", }, }, "required": ["operation"], }, } async def hf_repo_files_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]: """Handler for agent tool router.""" try: hf_token = session.hf_token if session else None tool = HfRepoFilesTool(hf_token=hf_token) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error: {str(e)}", False ================================================ FILE: agent/tools/hf_repo_git_tool.py ================================================ """ HF Repo Git Tool - Git-like operations on Hugging Face repositories Operations: branches, tags, PRs, repo management """ import asyncio from typing import Any, Dict, Literal, Optional from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError from agent.tools.types import ToolResult OperationType = Literal[ "create_branch", "delete_branch", "create_tag", "delete_tag", "list_refs", "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", "create_repo", "update_repo", ] async def _async_call(func, *args, **kwargs): """Wrap synchronous HfApi calls for async context.""" return await asyncio.to_thread(func, *args, **kwargs) def _build_repo_url(repo_id: str, repo_type: str = "model") -> str: """Build the Hub URL for a repository.""" if repo_type == "model": return f"https://huggingface.co/{repo_id}" return f"https://huggingface.co/{repo_type}s/{repo_id}" class HfRepoGitTool: """Tool for git-like operations on HF repos.""" def __init__(self, hf_token: Optional[str] = None): self.api = HfApi(token=hf_token) async def execute(self, args: Dict[str, Any]) -> ToolResult: """Execute the specified operation.""" operation = args.get("operation") if not operation: return self._help() try: handlers = { "create_branch": self._create_branch, "delete_branch": self._delete_branch, "create_tag": self._create_tag, "delete_tag": self._delete_tag, "list_refs": self._list_refs, "create_pr": self._create_pr, "list_prs": self._list_prs, "get_pr": self._get_pr, "merge_pr": self._merge_pr, "close_pr": self._close_pr, "comment_pr": self._comment_pr, "change_pr_status": self._change_pr_status, "create_repo": self._create_repo, "update_repo": self._update_repo, } handler = handlers.get(operation) if handler: return await handler(args) else: ops = ", ".join(handlers.keys()) return self._error(f"Unknown operation: {operation}. 
Valid: {ops}") except RepositoryNotFoundError: return self._error(f"Repository not found: {args.get('repo_id')}") except Exception as e: return self._error(f"Error: {str(e)}") def _help(self) -> ToolResult: """Show usage instructions.""" return { "formatted": """**hf_repo_git** - Git-like operations on HF repos **Branch/Tag:** - `create_branch`: `{"operation": "create_branch", "repo_id": "...", "branch": "dev"}` - `delete_branch`: `{"operation": "delete_branch", "repo_id": "...", "branch": "dev"}` - `create_tag`: `{"operation": "create_tag", "repo_id": "...", "tag": "v1.0"}` - `delete_tag`: `{"operation": "delete_tag", "repo_id": "...", "tag": "v1.0"}` - `list_refs`: `{"operation": "list_refs", "repo_id": "..."}` **PRs:** - `create_pr`: `{"operation": "create_pr", "repo_id": "...", "title": "..."}` (creates draft PR) - `list_prs`: `{"operation": "list_prs", "repo_id": "..."}` (shows status: draft/open/merged/closed) - `get_pr`: `{"operation": "get_pr", "repo_id": "...", "pr_num": 1}` (shows status) - `change_pr_status`: `{"operation": "change_pr_status", "repo_id": "...", "pr_num": 1, "new_status": "open"}` (change draft to open) - `merge_pr`: `{"operation": "merge_pr", "repo_id": "...", "pr_num": 1}` - `close_pr`: `{"operation": "close_pr", "repo_id": "...", "pr_num": 1}` - `comment_pr`: `{"operation": "comment_pr", "repo_id": "...", "pr_num": 1, "comment": "..."}` **Repo:** - `create_repo`: `{"operation": "create_repo", "repo_id": "my-model", "private": true}` - `update_repo`: `{"operation": "update_repo", "repo_id": "...", "private": false}`""", "totalResults": 1, "resultsShared": 1, } # ========================================================================= # BRANCH OPERATIONS # ========================================================================= async def _create_branch(self, args: Dict[str, Any]) -> ToolResult: """Create a new branch.""" repo_id = args.get("repo_id") branch = args.get("branch") if not repo_id: return self._error("repo_id is required") if not branch: return self._error("branch is required") repo_type = args.get("repo_type", "model") from_rev = args.get("from_rev", "main") await _async_call( self.api.create_branch, repo_id=repo_id, branch=branch, revision=from_rev, repo_type=repo_type, exist_ok=args.get("exist_ok", False), ) url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}" return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1} async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult: """Delete a branch.""" repo_id = args.get("repo_id") branch = args.get("branch") if not repo_id: return self._error("repo_id is required") if not branch: return self._error("branch is required") repo_type = args.get("repo_type", "model") await _async_call( self.api.delete_branch, repo_id=repo_id, branch=branch, repo_type=repo_type, ) return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # TAG OPERATIONS # ========================================================================= async def _create_tag(self, args: Dict[str, Any]) -> ToolResult: """Create a tag.""" repo_id = args.get("repo_id") tag = args.get("tag") if not repo_id: return self._error("repo_id is required") if not tag: return self._error("tag is required") repo_type = args.get("repo_type", "model") revision = args.get("revision", "main") tag_message = args.get("tag_message", "") await _async_call( self.api.create_tag, repo_id=repo_id, 
tag=tag, revision=revision, tag_message=tag_message, repo_type=repo_type, exist_ok=args.get("exist_ok", False), ) url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}" return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1} async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult: """Delete a tag.""" repo_id = args.get("repo_id") tag = args.get("tag") if not repo_id: return self._error("repo_id is required") if not tag: return self._error("tag is required") repo_type = args.get("repo_type", "model") await _async_call( self.api.delete_tag, repo_id=repo_id, tag=tag, repo_type=repo_type, ) return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # LIST REFS # ========================================================================= async def _list_refs(self, args: Dict[str, Any]) -> ToolResult: """List branches and tags.""" repo_id = args.get("repo_id") if not repo_id: return self._error("repo_id is required") repo_type = args.get("repo_type", "model") refs = await _async_call( self.api.list_repo_refs, repo_id=repo_id, repo_type=repo_type, ) branches = [b.name for b in refs.branches] if refs.branches else [] tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else [] url = _build_repo_url(repo_id, repo_type) lines = [f"**{repo_id}**", url, ""] if branches: lines.append(f"**Branches ({len(branches)}):** " + ", ".join(branches)) else: lines.append("**Branches:** none") if tags: lines.append(f"**Tags ({len(tags)}):** " + ", ".join(tags)) else: lines.append("**Tags:** none") return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)} # ========================================================================= # PR OPERATIONS # ========================================================================= async def _create_pr(self, args: Dict[str, Any]) -> ToolResult: """Create a pull request.""" repo_id = args.get("repo_id") title = args.get("title") if not repo_id: return self._error("repo_id is required") if not title: return self._error("title is required") repo_type = args.get("repo_type", "model") description = args.get("description", "") result = await _async_call( self.api.create_pull_request, repo_id=repo_id, title=title, description=description, repo_type=repo_type, ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}" return { "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"", "totalResults": 1, "resultsShared": 1, } async def _list_prs(self, args: Dict[str, Any]) -> ToolResult: """List PRs and discussions.""" repo_id = args.get("repo_id") if not repo_id: return self._error("repo_id is required") repo_type = args.get("repo_type", "model") status = args.get("status", "all") # open, closed, all discussions = list(await _async_call( self.api.get_repo_discussions, repo_id=repo_id, repo_type=repo_type, discussion_status=status if status != "all" else None, )) if not discussions: return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0} url = _build_repo_url(repo_id, repo_type) lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""] for d in discussions[:20]: if d.status == "draft": status_label = "[DRAFT]" elif d.status == "open": status_label = "[OPEN]" elif d.status == "merged": status_label = "[MERGED]" else:
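# Statuses other than draft/open/merged (i.e. closed) get the [CLOSED] label.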
status_label = "[CLOSED]" type_label = "PR" if d.is_pull_request else "D" lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}") return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))} async def _get_pr(self, args: Dict[str, Any]) -> ToolResult: """Get PR details.""" repo_id = args.get("repo_id") pr_num = args.get("pr_num") if not repo_id: return self._error("repo_id is required") if not pr_num: return self._error("pr_num is required") repo_type = args.get("repo_type", "model") pr = await _async_call( self.api.get_discussion_details, repo_id=repo_id, discussion_num=int(pr_num), repo_type=repo_type, ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" status_map = { "draft": "Draft", "open": "Open", "merged": "Merged", "closed": "Closed" } status = status_map.get(pr.status, pr.status.capitalize()) type_label = "Pull Request" if pr.is_pull_request else "Discussion" lines = [ f"**{type_label} #{pr_num}:** {pr.title}", f"**Status:** {status}", f"**Author:** {pr.author}", url, ] if pr.is_pull_request: if pr.status == "draft": lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") elif pr.status == "open": lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1} async def _merge_pr(self, args: Dict[str, Any]) -> ToolResult: """Merge a pull request.""" repo_id = args.get("repo_id") pr_num = args.get("pr_num") if not repo_id: return self._error("repo_id is required") if not pr_num: return self._error("pr_num is required") repo_type = args.get("repo_type", "model") comment = args.get("comment", "") await _async_call( self.api.merge_pull_request, repo_id=repo_id, discussion_num=int(pr_num), comment=comment, repo_type=repo_type, ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1} async def _close_pr(self, args: Dict[str, Any]) -> ToolResult: """Close a PR/discussion.""" repo_id = args.get("repo_id") pr_num = args.get("pr_num") if not repo_id: return self._error("repo_id is required") if not pr_num: return self._error("pr_num is required") repo_type = args.get("repo_type", "model") comment = args.get("comment", "") await _async_call( self.api.change_discussion_status, repo_id=repo_id, discussion_num=int(pr_num), new_status="closed", comment=comment, repo_type=repo_type, ) return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1} async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult: """Add a comment to a PR/discussion.""" repo_id = args.get("repo_id") pr_num = args.get("pr_num") comment = args.get("comment") if not repo_id: return self._error("repo_id is required") if not pr_num: return self._error("pr_num is required") if not comment: return self._error("comment is required") repo_type = args.get("repo_type", "model") await _async_call( self.api.comment_discussion, repo_id=repo_id, discussion_num=int(pr_num), comment=comment, repo_type=repo_type, ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1} async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult: """Change PR/discussion status (mainly to convert draft to open).""" repo_id = args.get("repo_id") pr_num = args.get("pr_num") new_status = 
args.get("new_status") if not repo_id: return self._error("repo_id is required") if not pr_num: return self._error("pr_num is required") if not new_status: return self._error("new_status is required (open or closed)") repo_type = args.get("repo_type", "model") comment = args.get("comment", "") await _async_call( self.api.change_discussion_status, repo_id=repo_id, discussion_num=int(pr_num), new_status=new_status, comment=comment, repo_type=repo_type, ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1} # ========================================================================= # REPO MANAGEMENT # ========================================================================= async def _create_repo(self, args: Dict[str, Any]) -> ToolResult: """Create a new repository.""" repo_id = args.get("repo_id") if not repo_id: return self._error("repo_id is required") repo_type = args.get("repo_type", "model") private = args.get("private", True) space_sdk = args.get("space_sdk") if repo_type == "space" and not space_sdk: return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)") kwargs = { "repo_id": repo_id, "repo_type": repo_type, "private": private, "exist_ok": args.get("exist_ok", False), } if space_sdk: kwargs["space_sdk"] = space_sdk result = await _async_call(self.api.create_repo, **kwargs) return { "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}", "totalResults": 1, "resultsShared": 1, } async def _update_repo(self, args: Dict[str, Any]) -> ToolResult: """Update repository settings.""" repo_id = args.get("repo_id") if not repo_id: return self._error("repo_id is required") repo_type = args.get("repo_type", "model") private = args.get("private") gated = args.get("gated") if private is None and gated is None: return self._error("Specify private (bool) or gated ('auto'/'manual'/false)") kwargs = {"repo_id": repo_id, "repo_type": repo_type} if private is not None: kwargs["private"] = private if gated is not None: kwargs["gated"] = gated await _async_call(self.api.update_repo_settings, **kwargs) changes = [] if private is not None: changes.append(f"private={private}") if gated is not None: changes.append(f"gated={gated}") url = f"{_build_repo_url(repo_id, repo_type)}/settings" return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1} def _error(self, message: str) -> ToolResult: """Return an error result.""" return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} # Tool specification HF_REPO_GIT_TOOL_SPEC = { "name": "hf_repo_git", "description": ( "Git-like operations on HF repos: branches, tags, PRs, and repo management.\n\n" "## Operations\n" "**Branches:** create_branch, delete_branch, list_refs\n" "**Tags:** create_tag, delete_tag\n" "**PRs:** create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr, change_pr_status\n" "**Repo:** create_repo, update_repo\n\n" "## Use when\n" "- Creating feature branches for experiments\n" "- Tagging model versions (v1.0, v2.0)\n" "- Opening PRs to contribute to repos you don't own\n" "- Reviewing and merging PRs on your repos\n" "- Creating new model/dataset/space repos\n" "- Changing repo visibility (public/private) or gated access\n\n" "## Examples\n" '{"operation": "list_refs", "repo_id": "my-model"}\n' '{"operation": "create_branch", "repo_id": "my-model", "branch": 
"experiment-v2"}\n' '{"operation": "create_tag", "repo_id": "my-model", "tag": "v1.0", "revision": "main"}\n' '{"operation": "create_pr", "repo_id": "org/model", "title": "Fix tokenizer config"}\n' '{"operation": "change_pr_status", "repo_id": "my-model", "pr_num": 1, "new_status": "open"}\n' '{"operation": "merge_pr", "repo_id": "my-model", "pr_num": 3}\n' '{"operation": "create_repo", "repo_id": "my-new-model", "private": true}\n' '{"operation": "update_repo", "repo_id": "my-model", "gated": "auto"}\n\n' "## PR Workflow\n" "1. create_pr → creates draft PR (empty by default)\n" "2. Upload files with revision='refs/pr/N' to add commits\n" "3. change_pr_status with new_status='open' to publish (convert draft to open)\n" "4. merge_pr when ready\n\n" "## Notes\n" "- PR status: draft (default), open, merged, closed\n" "- delete_branch, delete_tag, merge_pr, create_repo, update_repo require approval\n" "- For spaces, create_repo needs space_sdk (gradio/streamlit/docker/static)\n" "- gated options: 'auto' (instant), 'manual' (review), false (open)\n" ), "parameters": { "type": "object", "properties": { "operation": { "type": "string", "enum": [ "create_branch", "delete_branch", "create_tag", "delete_tag", "list_refs", "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", "create_repo", "update_repo", ], "description": "Operation to execute", }, "repo_id": { "type": "string", "description": "Repository ID (e.g., 'username/repo-name')", }, "repo_type": { "type": "string", "enum": ["model", "dataset", "space"], "description": "Repository type (default: model)", }, "branch": { "type": "string", "description": "Branch name (create_branch, delete_branch)", }, "from_rev": { "type": "string", "description": "Create branch from this revision (default: main)", }, "tag": { "type": "string", "description": "Tag name (create_tag, delete_tag)", }, "revision": { "type": "string", "description": "Revision for tag (default: main)", }, "tag_message": { "type": "string", "description": "Tag description", }, "title": { "type": "string", "description": "PR title (create_pr)", }, "description": { "type": "string", "description": "PR description (create_pr)", }, "pr_num": { "type": "integer", "description": "PR/discussion number", }, "comment": { "type": "string", "description": "Comment text", }, "status": { "type": "string", "enum": ["open", "closed", "all"], "description": "Filter PRs by status (list_prs)", }, "new_status": { "type": "string", "enum": ["open", "closed"], "description": "New status for PR/discussion (change_pr_status)", }, "private": { "type": "boolean", "description": "Make repo private (create_repo, update_repo)", }, "gated": { "type": "string", "enum": ["auto", "manual", "false"], "description": "Gated access setting (update_repo)", }, "space_sdk": { "type": "string", "enum": ["gradio", "streamlit", "docker", "static"], "description": "Space SDK (required for create_repo with space)", }, }, "required": ["operation"], }, } async def hf_repo_git_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]: """Handler for agent tool router.""" try: hf_token = session.hf_token if session else None tool = HfRepoGitTool(hf_token=hf_token) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error: {str(e)}", False ================================================ FILE: agent/tools/jobs_tool.py ================================================ """ Hugging Face Jobs Tool - Using 
huggingface-hub library Refactored to use official huggingface-hub library instead of custom HTTP client """ import asyncio import base64 import http.client import os import re from typing import Any, Dict, Literal, Optional, Callable, Awaitable import logging import httpx from huggingface_hub import HfApi from huggingface_hub.utils import HfHubHTTPError from agent.core.session import Event from agent.tools.types import ToolResult logger = logging.getLogger(__name__) from agent.tools.utilities import ( format_job_details, format_jobs_table, format_scheduled_job_details, format_scheduled_jobs_table, ) # Hardware flavors CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"] GPU_FLAVORS = [ "t4-small", "t4-medium", "a10g-small", "a10g-large", "a10g-largex2", "a10g-largex4", "a100-large", "a100x4", "a100x8", "l4x1", "l4x4", "l40sx1", "l40sx4", "l40sx8", ] # Detailed specs for display (vCPU/RAM/GPU VRAM) CPU_FLAVORS_DESC = "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB)" GPU_FLAVORS_DESC = ( "t4-small(4vCPU/15GB/GPU 16GB), t4-medium(8vCPU/30GB/GPU 16GB), " "a10g-small(4vCPU/15GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), " "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), " "a100-large(12vCPU/142GB/GPU 80GB), a100x4(48vCPU/568GB/GPU 320GB), a100x8(96vCPU/1136GB/GPU 640GB), " "l4x1(8vCPU/30GB/GPU 24GB), l4x4(48vCPU/186GB/GPU 96GB), " "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), l40sx8(192vCPU/1534GB/GPU 384GB)" ) SPECIALIZED_FLAVORS = ["inf2x6"] ALL_FLAVORS = CPU_FLAVORS + GPU_FLAVORS + SPECIALIZED_FLAVORS # Operation names OperationType = Literal[ "run", "ps", "logs", "inspect", "cancel", "scheduled run", "scheduled ps", "scheduled inspect", "scheduled delete", "scheduled suspend", "scheduled resume", ] # Constants UV_DEFAULT_IMAGE = "ghcr.io/astral-sh/uv:python3.12-bookworm" def _filter_uv_install_output(logs: list[str]) -> list[str]: """ Filter out UV package installation output from logs. Replaces installation details with "[installs truncated]" and keeps the "Installed X packages in Y ms/s" summary line. 
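Illustrative example (hypothetical log lines, shaped to match the regex below; exact uv output may vary):

    >>> _filter_uv_install_output([
    ...     "Resolved 12 packages in 80ms",
    ...     "Installed 12 packages in 45ms",
    ...     "training started",
    ... ])
    ['[installs truncated]', 'Installed 12 packages in 45ms', 'training started']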
Args: logs: List of log lines Returns: Filtered list of log lines """ if not logs: return logs # Regex pattern to match: "Installed X packages in Y ms" or "Installed X package in Y s" install_pattern = re.compile( r"^Installed\s+\d+\s+packages?\s+in\s+\d+(?:\.\d+)?\s*(?:ms|s)$" ) # Find the index of the "Installed X packages" line install_line_idx = None for idx, line in enumerate(logs): if install_pattern.match(line.strip()): install_line_idx = idx break # If pattern found, replace installation details with truncation message if install_line_idx is not None and install_line_idx > 0: # Keep logs from the "Installed X packages" line onward # Add truncation message before the "Installed" line return ["[installs truncated]"] + logs[install_line_idx:] # If pattern not found, return original logs return logs _ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07') def _strip_ansi(text: str) -> str: return _ANSI_RE.sub('', text) _DEFAULT_ENV = { "HF_HUB_DISABLE_PROGRESS_BARS": "1", "TQDM_DISABLE": "1", "TRANSFORMERS_VERBOSITY": "warning", "HF_HUB_ENABLE_HF_TRANSFER": "1", "UV_NO_PROGRESS": "1", } def _add_default_env(params: Dict[str, Any] | None) -> Dict[str, Any]: """Inject default env vars for clean, agent-friendly output.""" result = dict(_DEFAULT_ENV) result.update(params or {}) # user-provided values override defaults return result def _add_environment_variables( params: Dict[str, Any] | None, user_token: str | None = None ) -> Dict[str, Any]: token = user_token or "" # Start with user-provided env vars, then force-set token last result = dict(params or {}) # If the caller passed HF_TOKEN="$HF_TOKEN", ignore it. if result.get("HF_TOKEN", "").strip().startswith("$"): result.pop("HF_TOKEN", None) # Set both names to be safe (different libs check different vars) if token: result["HF_TOKEN"] = token result["HUGGINGFACE_HUB_TOKEN"] = token return result def _build_uv_command( script: str, with_deps: list[str] | None = None, python: str | None = None, script_args: list[str] | None = None, ) -> list[str]: """Build UV run command""" parts = ["uv", "run"] if with_deps: for dep in with_deps: parts.extend(["--with", dep]) if python: parts.extend(["-p", python]) parts.append(script) if script_args: parts.extend(script_args) # add defaults # parts.extend(["--push_to_hub"]) return parts def _wrap_inline_script( script: str, with_deps: list[str] | None = None, python: str | None = None, script_args: list[str] | None = None, ) -> str: """Wrap inline script with base64 encoding to avoid file creation""" encoded = base64.b64encode(script.encode("utf-8")).decode("utf-8") # Build the uv command with stdin (-) uv_command = _build_uv_command("-", with_deps, python, script_args) # Join command parts with proper spacing uv_command_str = " ".join(uv_command) return f'echo "{encoded}" | base64 -d | {uv_command_str}' def _ensure_hf_transfer_dependency(deps: list[str] | None) -> list[str]: """Ensure hf-transfer is included in the dependencies list""" if isinstance(deps, list): deps_copy = deps.copy() # Don't modify the original if "hf-transfer" not in deps_copy: deps_copy.append("hf-transfer") return deps_copy return ["hf-transfer"] def _resolve_uv_command( script: str, with_deps: list[str] | None = None, python: str | None = None, script_args: list[str] | None = None, ) -> list[str]: """Resolve UV command based on script source (URL, inline, or file path)""" # If URL, use directly if script.startswith("http://") or script.startswith("https://"): return _build_uv_command(script, with_deps, python, script_args) # 
If contains newline, treat as inline script if "\n" in script: wrapped = _wrap_inline_script(script, with_deps, python, script_args) return ["/bin/sh", "-lc", wrapped] # Otherwise, treat as file path return _build_uv_command(script, with_deps, python, script_args) async def _async_call(func, *args, **kwargs): """Wrap synchronous HfApi calls for async context""" return await asyncio.to_thread(func, *args, **kwargs) def _job_info_to_dict(job_info) -> Dict[str, Any]: """Convert JobInfo object to dictionary for formatting functions""" return { "id": job_info.id, "status": {"stage": job_info.status.stage, "message": job_info.status.message}, "command": job_info.command, "createdAt": job_info.created_at.isoformat(), "dockerImage": job_info.docker_image, "spaceId": job_info.space_id, "hardware_flavor": job_info.flavor, "owner": {"name": job_info.owner.name}, } def _scheduled_job_info_to_dict(scheduled_job_info) -> Dict[str, Any]: """Convert ScheduledJobInfo object to dictionary for formatting functions""" job_spec = scheduled_job_info.job_spec # Extract last run and next run from status last_run = None next_run = None if scheduled_job_info.status: if scheduled_job_info.status.last_job: last_run = scheduled_job_info.status.last_job.created_at if last_run: last_run = ( last_run.isoformat() if hasattr(last_run, "isoformat") else str(last_run) ) if scheduled_job_info.status.next_job_run_at: next_run = scheduled_job_info.status.next_job_run_at next_run = ( next_run.isoformat() if hasattr(next_run, "isoformat") else str(next_run) ) return { "id": scheduled_job_info.id, "schedule": scheduled_job_info.schedule, "suspend": scheduled_job_info.suspend, "lastRun": last_run, "nextRun": next_run, "jobSpec": { "dockerImage": job_spec.docker_image, "spaceId": job_spec.space_id, "command": job_spec.command or [], "hardware_flavor": job_spec.flavor or "cpu-basic", }, } class HfJobsTool: """Tool for managing Hugging Face compute jobs using huggingface-hub library""" def __init__( self, hf_token: Optional[str] = None, namespace: Optional[str] = None, log_callback: Optional[Callable[[str], Awaitable[None]]] = None, session: Any = None, tool_call_id: Optional[str] = None, ): self.hf_token = hf_token self.api = HfApi(token=hf_token) self.namespace = namespace self.log_callback = log_callback self.session = session self.tool_call_id = tool_call_id async def execute(self, params: Dict[str, Any]) -> ToolResult: """Execute the specified operation""" operation = params.get("operation") args = params # If no operation provided, return error if not operation: return { "formatted": "Error: 'operation' parameter is required. 
See tool description for available operations and usage examples.", "totalResults": 0, "resultsShared": 0, "isError": True, } # Normalize operation name operation = operation.lower() try: # Route to appropriate handler if operation == "run": return await self._run_job(args) elif operation == "ps": return await self._list_jobs(args) elif operation == "logs": return await self._get_logs(args) elif operation == "inspect": return await self._inspect_job(args) elif operation == "cancel": return await self._cancel_job(args) elif operation == "scheduled run": return await self._scheduled_run(args) elif operation == "scheduled ps": return await self._list_scheduled_jobs(args) elif operation == "scheduled inspect": return await self._inspect_scheduled_job(args) elif operation == "scheduled delete": return await self._delete_scheduled_job(args) elif operation == "scheduled suspend": return await self._suspend_scheduled_job(args) elif operation == "scheduled resume": return await self._resume_scheduled_job(args) else: return { "formatted": f'Unknown operation: "{operation}"\n\n' "Available operations:\n" "- run, ps, logs, inspect, cancel\n" "- scheduled run, scheduled ps, scheduled inspect, " "scheduled delete, scheduled suspend, scheduled resume\n\n" "Call this tool with no operation for full usage instructions.", "totalResults": 0, "resultsShared": 0, "isError": True, } except HfHubHTTPError as e: return { "formatted": f"API Error: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } except Exception as e: return { "formatted": f"Error executing {operation}: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } async def _wait_for_job_completion( self, job_id: str, namespace: Optional[str] = None ) -> tuple[str, list[str]]: """ Stream job logs until completion, printing them in real-time. Implements retry logic to handle connection drops during long-running jobs. 
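Implementation note: fetch_job_logs is a blocking sync generator, so a producer thread pushes each line onto an asyncio.Queue via loop.call_soon_threadsafe; the async consumer awaits queue.get(), treating a None item as EOF and an Exception item as a producer error. If the connection drops, the job is re-inspected and streaming resumes unless a terminal state (COMPLETED/FAILED/CANCELED/ERROR) was already reached.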
Returns: tuple: (final_status, all_logs) """ all_logs = [] terminal_states = {"COMPLETED", "FAILED", "CANCELED", "ERROR"} max_retries = 100 # Allow many retries for 8h+ jobs retry_delay = 5 # Seconds between retries for _ in range(max_retries): try: # Use a queue to bridge sync generator to async consumer queue = asyncio.Queue() loop = asyncio.get_running_loop() def log_producer(): try: # fetch_job_logs is a blocking sync generator logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace) for line in logs_gen: # Push line to queue thread-safely loop.call_soon_threadsafe(queue.put_nowait, line) # Signal EOF loop.call_soon_threadsafe(queue.put_nowait, None) except Exception as e: # Signal error loop.call_soon_threadsafe(queue.put_nowait, e) # Start producer in a background thread so it doesn't block the event loop producer_future = loop.run_in_executor(None, log_producer) # Consume logs from the queue as they arrive while True: item = await queue.get() # EOF sentinel if item is None: break # Error occurred in producer if isinstance(item, Exception): raise item # Process log line log_line = item logger.debug(log_line) if self.log_callback: await self.log_callback(log_line) all_logs.append(log_line) # If we get here, streaming completed normally (EOF received) # Wait for thread to cleanup (should be done) await producer_future break except ( ConnectionError, TimeoutError, OSError, http.client.IncompleteRead, httpx.RemoteProtocolError, httpx.ReadError, HfHubHTTPError, ) as e: # Connection dropped - check if job is still running try: job_info = await _async_call( self.api.inspect_job, job_id=job_id, namespace=namespace ) current_status = job_info.status.stage if current_status in terminal_states: # Job finished, no need to retry logger.info(f"Job reached terminal state: {current_status}") break # Job still running, retry connection logger.warning( f"Connection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..." ) await asyncio.sleep(retry_delay) continue except (ConnectionError, TimeoutError, OSError): # Can't even check job status, wait and retry logger.warning(f"Connection error, retrying in {retry_delay}s...") await asyncio.sleep(retry_delay) continue # Fetch final job status — retry briefly if still RUNNING # (the API may lag a few seconds behind the log stream ending) final_status = "UNKNOWN" for _ in range(6): job_info = await _async_call( self.api.inspect_job, job_id=job_id, namespace=namespace ) final_status = job_info.status.stage if final_status in terminal_states: break await asyncio.sleep(2.5) return final_status, all_logs async def _run_job(self, args: Dict[str, Any]) -> ToolResult: """Run a job using HfApi.run_job() - smart detection of Python vs Docker mode""" try: script = args.get("script") command = args.get("command") # Validate mutually exclusive parameters if script and command: raise ValueError( "'script' and 'command' are mutually exclusive. Provide one or the other, not both." ) if not script and not command: raise ValueError( "Either 'script' (for Python) or 'command' (for Docker) must be provided." 
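# After this validation exactly one mode is active. A 'script' value is resolved by
# _resolve_uv_command: URLs and file paths are handed to `uv run` directly, while inline
# scripts (any string containing a newline) are base64-encoded and piped to `uv run -`
# via /bin/sh, so no temporary file has to be created first.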
) # Python mode: script provided if script: # Get dependencies and ensure hf-transfer is included deps = _ensure_hf_transfer_dependency(args.get("dependencies")) # Resolve the command based on script type (URL, inline, or file) command = _resolve_uv_command( script=script, with_deps=deps, python=args.get("python"), script_args=args.get("script_args"), ) # Use UV image unless overridden image = args.get("image", UV_DEFAULT_IMAGE) job_type = "Python" # Docker mode: command provided else: image = args.get("image", "python:3.12") job_type = "Docker" # Run the job job = await _async_call( self.api.run_job, image=image, command=command, env=_add_default_env(args.get("env")), secrets=_add_environment_variables(args.get("secrets"), self.hf_token), flavor=args.get("hardware_flavor", "cpu-basic"), timeout=args.get("timeout", "30m"), namespace=self.namespace, ) # Track job ID for cancellation on interrupt if self.session: self.session._running_job_ids.add(job.id) # Send job URL immediately after job creation (before waiting for completion) if self.session and self.tool_call_id: await self.session.send_event( Event( event_type="tool_state_change", data={ "tool_call_id": self.tool_call_id, "tool": "hf_jobs", "state": "running", "jobUrl": job.url, }, ) ) # Wait for completion and stream logs logger.info(f"{job_type} job started: {job.url}") logger.info("Streaming logs...") final_status, all_logs = await self._wait_for_job_completion( job_id=job.id, namespace=self.namespace, ) # Untrack job ID (completed or failed, no longer needs cancellation) if self.session: self.session._running_job_ids.discard(job.id) # Notify frontend of final status if self.session and self.tool_call_id: await self.session.send_event( Event( event_type="tool_state_change", data={ "tool_call_id": self.tool_call_id, "tool": "hf_jobs", "state": final_status.lower(), "jobUrl": job.url, }, ) ) # Filter out UV package installation output filtered_logs = _filter_uv_install_output(all_logs) # Format all logs for the agent log_text = _strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)" response = f"""{job_type} job finished. **Job ID:** {job.id} **Final Status:** {final_status} **View at:** {job.url} **Logs:** ``` {log_text} ```""" return {"formatted": response, "totalResults": 1, "resultsShared": 1} except Exception as e: raise Exception(f"Failed to run job: {str(e)}") async def _list_jobs(self, args: Dict[str, Any]) -> ToolResult: """List jobs using HfApi.list_jobs()""" jobs_list = await _async_call(self.api.list_jobs, namespace=self.namespace) # Filter jobs if not args.get("all", False): jobs_list = [j for j in jobs_list if j.status.stage == "RUNNING"] if args.get("status"): status_filter = args["status"].upper() jobs_list = [j for j in jobs_list if status_filter in j.status.stage] # Convert JobInfo objects to dicts for formatting jobs_dicts = [_job_info_to_dict(j) for j in jobs_list] table = format_jobs_table(jobs_dicts) if len(jobs_list) == 0: if args.get("all", False): return { "formatted": "No jobs found.", "totalResults": 0, "resultsShared": 0, } return { "formatted": 'No running jobs found. 
Use `{"operation": "ps", "all": true}` to show all jobs.', "totalResults": 0, "resultsShared": 0, } response = f"**Jobs ({len(jobs_list)} total):**\n\n{table}" return { "formatted": response, "totalResults": len(jobs_list), "resultsShared": len(jobs_list), } async def _get_logs(self, args: Dict[str, Any]) -> ToolResult: """Fetch logs using HfApi.fetch_job_logs()""" job_id = args.get("job_id") if not job_id: return { "formatted": "job_id is required", "isError": True, "totalResults": 0, "resultsShared": 0, } try: # Fetch logs (returns generator, convert to list) logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=self.namespace) logs = await _async_call(list, logs_gen) if not logs: return { "formatted": f"No logs available for job {job_id}", "totalResults": 0, "resultsShared": 0, } log_text = _strip_ansi("\n".join(logs)) return { "formatted": f"**Logs for {job_id}:**\n\n```\n{log_text}\n```", "totalResults": 1, "resultsShared": 1, } except Exception as e: return { "formatted": f"Failed to fetch logs: {str(e)}", "isError": True, "totalResults": 0, "resultsShared": 0, } async def _inspect_job(self, args: Dict[str, Any]) -> ToolResult: """Inspect job using HfApi.inspect_job()""" job_id = args.get("job_id") if not job_id: return { "formatted": "job_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } job_ids = job_id if isinstance(job_id, list) else [job_id] jobs = [] for jid in job_ids: try: job = await _async_call( self.api.inspect_job, job_id=jid, namespace=self.namespace, ) jobs.append(_job_info_to_dict(job)) except Exception as e: raise Exception(f"Failed to inspect job {jid}: {str(e)}") formatted_details = format_job_details(jobs) response = f"**Job Details** ({len(jobs)} job{'s' if len(jobs) > 1 else ''}):\n\n{formatted_details}" return { "formatted": response, "totalResults": len(jobs), "resultsShared": len(jobs), } async def _cancel_job(self, args: Dict[str, Any]) -> ToolResult: """Cancel job using HfApi.cancel_job()""" job_id = args.get("job_id") if not job_id: return { "formatted": "job_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } await _async_call( self.api.cancel_job, job_id=job_id, namespace=self.namespace, ) response = f"""✓ Job {job_id} has been cancelled. To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}}`""" return {"formatted": response, "totalResults": 1, "resultsShared": 1} async def _scheduled_run(self, args: Dict[str, Any]) -> ToolResult: """Create scheduled job using HfApi.create_scheduled_job() - smart detection of Python vs Docker mode""" try: script = args.get("script") command = args.get("command") schedule = args.get("schedule") if not schedule: raise ValueError("schedule is required for scheduled jobs") # Validate mutually exclusive parameters if script and command: raise ValueError( "'script' and 'command' are mutually exclusive. Provide one or the other, not both." ) if not script and not command: raise ValueError( "Either 'script' (for Python) or 'command' (for Docker) must be provided." 
) # Python mode: script provided if script: # Get dependencies and ensure hf-transfer is included deps = _ensure_hf_transfer_dependency(args.get("dependencies")) # Resolve the command based on script type command = _resolve_uv_command( script=script, with_deps=deps, python=args.get("python"), script_args=args.get("script_args"), ) # Use UV image unless overridden image = args.get("image", UV_DEFAULT_IMAGE) job_type = "Python" # Docker mode: command provided else: image = args.get("image", "python:3.12") job_type = "Docker" # Create scheduled job scheduled_job = await _async_call( self.api.create_scheduled_job, image=image, command=command, schedule=schedule, env=_add_default_env(args.get("env")), secrets=_add_environment_variables(args.get("secrets"), self.hf_token), flavor=args.get("hardware_flavor", "cpu-basic"), timeout=args.get("timeout", "30m"), namespace=self.namespace, ) scheduled_dict = _scheduled_job_info_to_dict(scheduled_job) response = f"""✓ Scheduled {job_type} job created successfully! **Scheduled Job ID:** {scheduled_dict["id"]} **Schedule:** {scheduled_dict["schedule"]} **Suspended:** {"Yes" if scheduled_dict.get("suspend") else "No"} **Next Run:** {scheduled_dict.get("nextRun", "N/A")} To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_job_id": "{scheduled_dict["id"]}"}}` To list all, call this tool with `{{"operation": "scheduled ps"}}`""" return {"formatted": response, "totalResults": 1, "resultsShared": 1} except Exception as e: raise Exception(f"Failed to create scheduled job: {str(e)}") async def _list_scheduled_jobs(self, args: Dict[str, Any]) -> ToolResult: """List scheduled jobs using HfApi.list_scheduled_jobs()""" scheduled_jobs_list = await _async_call( self.api.list_scheduled_jobs, namespace=self.namespace, ) # Filter jobs - default: hide suspended jobs unless --all is specified if not args.get("all", False): scheduled_jobs_list = [j for j in scheduled_jobs_list if not j.suspend] # Convert to dicts for formatting scheduled_dicts = [_scheduled_job_info_to_dict(j) for j in scheduled_jobs_list] table = format_scheduled_jobs_table(scheduled_dicts) if len(scheduled_jobs_list) == 0: if args.get("all", False): return { "formatted": "No scheduled jobs found.", "totalResults": 0, "resultsShared": 0, } return { "formatted": 'No active scheduled jobs found. 
Use `{"operation": "scheduled ps", "all": true}` to show suspended jobs.', "totalResults": 0, "resultsShared": 0, } response = f"**Scheduled Jobs ({len(scheduled_jobs_list)} total):**\n\n{table}" return { "formatted": response, "totalResults": len(scheduled_jobs_list), "resultsShared": len(scheduled_jobs_list), } async def _inspect_scheduled_job(self, args: Dict[str, Any]) -> ToolResult: """Inspect scheduled job using HfApi.inspect_scheduled_job()""" scheduled_job_id = args.get("scheduled_job_id") if not scheduled_job_id: return { "formatted": "scheduled_job_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } scheduled_job = await _async_call( self.api.inspect_scheduled_job, scheduled_job_id=scheduled_job_id, namespace=self.namespace, ) scheduled_dict = _scheduled_job_info_to_dict(scheduled_job) formatted_details = format_scheduled_job_details(scheduled_dict) return { "formatted": f"**Scheduled Job Details:**\n\n{formatted_details}", "totalResults": 1, "resultsShared": 1, } async def _delete_scheduled_job(self, args: Dict[str, Any]) -> ToolResult: """Delete scheduled job using HfApi.delete_scheduled_job()""" scheduled_job_id = args.get("scheduled_job_id") if not scheduled_job_id: return { "formatted": "scheduled_job_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } await _async_call( self.api.delete_scheduled_job, scheduled_job_id=scheduled_job_id, namespace=self.namespace, ) return { "formatted": f"✓ Scheduled job {scheduled_job_id} has been deleted.", "totalResults": 1, "resultsShared": 1, } async def _suspend_scheduled_job(self, args: Dict[str, Any]) -> ToolResult: """Suspend scheduled job using HfApi.suspend_scheduled_job()""" scheduled_job_id = args.get("scheduled_job_id") if not scheduled_job_id: return { "formatted": "scheduled_job_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } await _async_call( self.api.suspend_scheduled_job, scheduled_job_id=scheduled_job_id, namespace=self.namespace, ) response = f"""✓ Scheduled job {scheduled_job_id} has been suspended. To resume, call this tool with `{{"operation": "scheduled resume", "scheduled_job_id": "{scheduled_job_id}"}}`""" return {"formatted": response, "totalResults": 1, "resultsShared": 1} async def _resume_scheduled_job(self, args: Dict[str, Any]) -> ToolResult: """Resume scheduled job using HfApi.resume_scheduled_job()""" scheduled_job_id = args.get("scheduled_job_id") if not scheduled_job_id: return { "formatted": "scheduled_job_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } await _async_call( self.api.resume_scheduled_job, scheduled_job_id=scheduled_job_id, namespace=self.namespace, ) response = f"""✓ Scheduled job {scheduled_job_id} has been resumed. To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_job_id": "{scheduled_job_id}"}}`""" return {"formatted": response, "totalResults": 1, "resultsShared": 1} # Tool specification for agent registration HF_JOBS_TOOL_SPEC = { "name": "hf_jobs", "description": ( "Execute Python scripts or Docker containers on HF cloud infrastructure.\n\n" "Two modes (mutually exclusive): Python mode (script + dependencies) or Docker mode (command + image). " "Provide exactly ONE of 'script' or 'command'.\n\n" "BEFORE submitting training/fine-tuning jobs:\n" "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. 
" "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n" "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n" "- Training config MUST include push_to_hub=True and hub_model_id. " "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n" "- Include trackio monitoring and provide the dashboard URL to the user.\n\n" "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. " "Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n" "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n" f"Hardware: CPU: {CPU_FLAVORS_DESC}. GPU: {GPU_FLAVORS_DESC}.\n" "Common picks: t4-small ($0.60/hr, 1-3B), a10g-large ($2/hr, 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+). " "Note: a10g-small and a10g-large have the SAME 24GB GPU — the difference is CPU/RAM only.\n\n" "OOM RECOVERY: When a training job fails with CUDA OOM:\n" "1. Reduce per_device_train_batch_size and increase gradient_accumulation_steps proportionally (keep effective batch size identical)\n" "2. Enable gradient_checkpointing=True\n" "3. Upgrade to larger GPU (a10g→a100→h100)\n" "Do NOT switch training methods (e.g. full SFT to LoRA) or reduce max_length — those change what the user gets and require explicit approval.\n\n" "Examples:\n" "Training: {'operation': 'run', 'script': '/app/train.py', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a100-large', 'timeout': '8h'}\n" "Monitor: {'operation': 'ps'}, {'operation': 'logs', 'job_id': 'xxx'}, {'operation': 'cancel', 'job_id': 'xxx'}" "Docker: {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2'], 'image': 'duckdb/duckdb', 'hardware_flavor': 'cpu-basic', 'timeout': '1h'}\n" ), "parameters": { "type": "object", "properties": { "operation": { "type": "string", "enum": [ "run", "ps", "logs", "inspect", "cancel", "scheduled run", "scheduled ps", "scheduled inspect", "scheduled delete", "scheduled suspend", "scheduled resume", ], "description": "Operation to execute.", }, "script": { "type": "string", "description": ( "Python code or sandbox file path (e.g. '/app/train.py') or URL. " "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. " "Mutually exclusive with 'command'." ), }, "dependencies": { "type": "array", "items": {"type": "string"}, "description": ( "Pip packages to install. Include ALL required packages. " "Common training set: ['transformers', 'trl', 'torch', 'datasets', 'trackio', 'accelerate']. " "Only used with 'script'." ), }, "image": { "type": "string", "description": "Docker image. Optional — auto-selected if not provided. Use with 'command'.", }, "command": { "type": "array", "items": {"type": "string"}, "description": "Command to execute as list. Triggers Docker mode. Mutually exclusive with 'script'.", }, "hardware_flavor": { "type": "string", "description": ( "Hardware type. Sizing guide: 1-3B params → t4-small/a10g-small, " "7-13B → a10g-large, 30B+ → a100-large, 70B+ → h100/h100x8. " f"All options: CPU: {CPU_FLAVORS}. GPU: {GPU_FLAVORS}." ), }, "timeout": { "type": "string", "description": ( "Maximum job runtime. MUST be >2h for any training job — default 30m kills training mid-run. " "Guidelines: 1-3B models: 3-4h, 7-13B: 6-8h, 30B+: 12-24h. 
" "Use 30m-1h only for quick data processing or inference tasks. Default: '30m'." ), }, "env": { "type": "object", "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.", }, "job_id": { "type": "string", "description": "Job ID. Required for: logs, inspect, cancel.", }, "scheduled_job_id": { "type": "string", "description": "Scheduled job ID. Required for: scheduled inspect/delete/suspend/resume.", }, "schedule": { "type": "string", "description": "Cron schedule or preset (@hourly, @daily, @weekly, @monthly). Required for: scheduled run.", }, }, "required": ["operation"], }, } async def hf_jobs_handler( arguments: Dict[str, Any], session: Any = None, tool_call_id: str | None = None ) -> tuple[str, bool]: """Handler for agent tool router""" try: async def log_callback(log: str): if session: await session.send_event( Event(event_type="tool_log", data={"tool": "hf_jobs", "log": log}) ) # If script is a sandbox file path, read it from the sandbox script = arguments.get("script", "") sandbox = getattr(session, "sandbox", None) if session else None if sandbox and script: from agent.tools.sandbox_tool import resolve_sandbox_script content, error = await resolve_sandbox_script(sandbox, script) if error: return error, False if content: arguments = {**arguments, "script": content} hf_token = session.hf_token if session else None namespace = os.environ.get("HF_NAMESPACE") or (HfApi(token=hf_token).whoami().get("name") if hf_token else None) tool = HfJobsTool( namespace=namespace, hf_token=hf_token, log_callback=log_callback if session else None, session=session, tool_call_id=tool_call_id, ) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error executing HF Jobs tool: {str(e)}", False ================================================ FILE: agent/tools/local_tools.py ================================================ """ Local tool implementations — bash/read/write/edit running on the user's machine. Drop-in replacement for sandbox tools when running in CLI (local) mode. Same tool specs (names, parameters) but handlers execute locally via subprocess/pathlib instead of going through a remote sandbox. """ from __future__ import annotations import os import re import subprocess import tempfile from pathlib import Path from typing import Any MAX_OUTPUT_CHARS = 25_000 MAX_LINE_LENGTH = 4000 DEFAULT_READ_LINES = 2000 DEFAULT_TIMEOUT = 120 MAX_TIMEOUT = 36000 # 10 hours — needed for long training runs (e.g. PostTrainBench) _ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07') # Track files that have been read this session (enforces read-before-write/edit) _files_read: set[str] = set() def _resolve_path(path: str) -> str: try: return str(Path(path).resolve()) except Exception: return path def _atomic_write(path: Path, content: str) -> None: """Write file atomically via temp file + os.replace(). Ensures the file is never left in a partial/corrupted state — it's either the old content or the new content, never half-written. 
""" path.parent.mkdir(parents=True, exist_ok=True) fd = None tmp_path = None try: fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp") os.write(fd, content.encode("utf-8")) os.fsync(fd) os.close(fd) fd = None os.replace(tmp_path, str(path)) tmp_path = None # successfully replaced, nothing to clean up finally: if fd is not None: os.close(fd) if tmp_path is not None: try: os.unlink(tmp_path) except OSError: pass def _strip_ansi(text: str) -> str: return _ANSI_RE.sub('', text) def _truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25) -> str: """Tail-biased truncation with temp file spillover for full output access.""" if len(output) <= max_chars: return output # Write full output to temp file so LLM can read specific sections spill_path = None try: with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', prefix='bash_output_', delete=False) as f: f.write(output) spill_path = f.name except Exception: pass head_budget = int(max_chars * head_ratio) tail_budget = max_chars - head_budget head = output[:head_budget] tail = output[-tail_budget:] total = len(output) omitted = total - max_chars meta = f"\n\n... ({omitted:,} of {total:,} chars omitted, showing first {head_budget:,} + last {tail_budget:,}) ...\n" if spill_path: meta += f"Full output saved to {spill_path} — use the read tool with offset/limit to inspect specific sections.\n" meta += "IMPORTANT: The command has finished. Analyze the output above and continue with your next action.\n" return head + meta + tail # ── Handlers ──────────────────────────────────────────────────────────── async def _bash_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: command = args.get("command", "") if not command: return "No command provided.", False work_dir = args.get("work_dir", ".") timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT) try: result = subprocess.run( command, shell=True, capture_output=True, text=True, cwd=work_dir, timeout=timeout, ) output = _strip_ansi(result.stdout + result.stderr) output = _truncate_output(output) if not output.strip(): output = "(no output)" return output, result.returncode == 0 except subprocess.TimeoutExpired: return ( f"Command timed out after {timeout}s and was killed.\n\n" f"For long-running commands, run in the background and poll:\n" f" nohup > /tmp/output.log 2>&1 & echo $!\n" f"Then check status with:\n" f" kill -0 2>/dev/null && echo 'running' || echo 'done'\n" f" tail -n 50 /tmp/output.log" ), False except Exception as e: return f"bash error: {e}", False async def _read_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: file_path = args.get("path", "") if not file_path: return "No path provided.", False p = Path(file_path) if not p.exists(): return f"File not found: {file_path}", False if p.is_dir(): return "Cannot read a directory. Use bash with 'ls' instead.", False try: raw_content = p.read_text() except Exception as e: return f"read error: {e}", False _files_read.add(_resolve_path(file_path)) lines = raw_content.splitlines() offset = max((args.get("offset") or 1), 1) limit = args.get("limit") or DEFAULT_READ_LINES selected = lines[offset - 1 : offset - 1 + limit] numbered = [] for i, line in enumerate(selected, start=offset): if len(line) > MAX_LINE_LENGTH: line = line[:MAX_LINE_LENGTH] + "..." 
numbered.append(f"{i:>6}\t{line}") return "\n".join(numbered), True async def _write_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: file_path = args.get("path", "") content = args.get("content", "") if not file_path: return "No path provided.", False p = Path(file_path) if p.exists() and _resolve_path(file_path) not in _files_read: return ( f"You must read {file_path} before overwriting it. " f"Use the read tool first to see current contents." ), False try: _atomic_write(p, content) _files_read.add(_resolve_path(file_path)) msg = f"Wrote {len(content)} bytes to {file_path}" # Syntax validation for Python files if p.suffix == ".py": from agent.tools.edit_utils import validate_python warnings = validate_python(content, file_path) if warnings: msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings) return msg, True except Exception as e: return f"write error: {e}", False async def _edit_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: from agent.tools.edit_utils import apply_edit, validate_python file_path = args.get("path", "") old_str = args.get("old_str", "") new_str = args.get("new_str", "") replace_all = args.get("replace_all", False) mode = args.get("mode", "replace") if not file_path: return "No path provided.", False if old_str == new_str: return "old_str and new_str must differ.", False p = Path(file_path) if not p.exists(): return f"File not found: {file_path}", False if _resolve_path(file_path) not in _files_read: return ( f"You must read {file_path} before editing it. " f"Use the read tool first to see current contents." ), False try: text = p.read_text() except Exception as e: return f"edit read error: {e}", False try: new_text, replacements, fuzzy_note = apply_edit( text, old_str, new_str, mode=mode, replace_all=replace_all ) except ValueError as e: return str(e), False try: _atomic_write(p, new_text) except Exception as e: return f"edit write error: {e}", False msg = f"Edited {file_path} ({replacements} replacement{'s' if replacements > 1 else ''})" if fuzzy_note: msg += f" {fuzzy_note}" # Syntax validation for Python files if p.suffix == ".py": warnings = validate_python(new_text, file_path) if warnings: msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings) return msg, True # ── Local tool specs (override sandbox /app references) ──────────────── _LOCAL_TOOL_SPECS = { "bash": { "description": ( "Run a shell command on the local machine and return stdout/stderr.\n" "\n" "IMPORTANT: Do NOT use bash for file operations — use the dedicated tools instead:\n" "- To read files: use read (not cat/head/tail)\n" "- To edit files: use edit (not sed/awk)\n" "- To write files: use write (not echo/cat < > /tmp/output.log 2>&1 & echo $!\n" "Then check status:\n" " kill -0 2>/dev/null && echo 'running' || echo 'done'\n" " tail -n 50 /tmp/output.log\n" "\n" "Timeout default 120s, max 36000s." ), "parameters": { "type": "object", "required": ["command"], "additionalProperties": False, "properties": { "command": { "type": "string", "description": "The shell command to execute.", }, "description": { "type": "string", "description": "Short description (5-10 words, active voice).", }, "work_dir": { "type": "string", "description": "Working directory (default: current directory).", }, "timeout": { "type": "integer", "description": "Optional timeout in seconds (default: 120, max: 36000).", }, }, }, }, "read": { "description": ( "Reads a file from the local filesystem. 
Returns contents with line numbers " "(cat -n format).\n" "\n" "Usage:\n" "- By default, reads up to 2000 lines from the beginning of the file.\n" "- You can optionally specify offset and limit for large files, but prefer " "reading the whole file first.\n" "- Lines longer than 4000 chars are truncated.\n" "- Cannot read directories — use bash with 'ls' instead.\n" "- You should read multiple potentially useful files in parallel when possible.\n" "- IMPORTANT: Always read a file before editing or overwriting it. The edit and " "write tools will reject operations on files you haven't read." ), "parameters": { "type": "object", "required": ["path"], "additionalProperties": False, "properties": { "path": { "type": "string", "description": "Absolute path to the file to read.", }, "offset": { "type": "integer", "description": "The line number to start reading from (1-based). Only provide if the file is too large to read at once.", }, "limit": { "type": "integer", "description": "The number of lines to read. Only provide if the file is too large to read at once.", }, }, }, }, "write": { "description": ( "Writes a file to the local filesystem. Overwrites the existing file if one " "exists at the path.\n" "\n" "- If this is an existing file, you MUST use the read tool first. This tool " "will fail if you did not read the file first.\n" "- ALWAYS prefer editing existing files with the edit tool over overwriting " "with write.\n" "- Creates parent directories as needed." ), "parameters": { "type": "object", "required": ["path", "content"], "additionalProperties": False, "properties": { "path": { "type": "string", "description": "Absolute path to the file to write.", }, "content": { "type": "string", "description": "The complete file content to write.", }, }, }, }, "edit": { "description": ( "Performs string replacements in files. Supports exact matching with " "fuzzy fallback.\n" "\n" "Usage:\n" "- You must read the file at least once before editing. This tool will " "error if you attempt an edit without reading the file.\n" "- The edit will FAIL if old_str is not unique in the file. Either provide " "a larger string with more surrounding context to make it unique, or set " "replace_all to true.\n" "- old_str and new_str must differ.\n" "- Preserve indentation exactly as it appears in the file.\n" "- Do NOT include line number prefixes from read output in old_str or new_str.\n" "- To delete code, set new_str to empty string.\n" "- Use replace_all for renaming variables or strings across the file.\n" "\n" "Modes:\n" "- replace (default): replace first occurrence of old_str with new_str.\n" "- append_after: insert new_str immediately after old_str (old_str is kept).\n" "- prepend_before: insert new_str immediately before old_str (old_str is kept)." ), "parameters": { "type": "object", "required": ["path", "old_str", "new_str"], "additionalProperties": False, "properties": { "path": { "type": "string", "description": "Absolute path to the file to edit.", }, "old_str": { "type": "string", "description": "The text to find in the file. Must match exactly (fuzzy matching is used as fallback).", }, "new_str": { "type": "string", "description": "The replacement text. 
For append_after/prepend_before modes, the text to insert.", }, "replace_all": { "type": "boolean", "description": "Replace all occurrences of old_str (default: false).", "default": False, }, "mode": { "type": "string", "enum": ["replace", "append_after", "prepend_before"], "description": "Edit mode (default: replace).", "default": "replace", }, }, }, }, } _HANDLERS = { "bash": _bash_handler, "read": _read_handler, "write": _write_handler, "edit": _edit_handler, } def get_local_tools(): """Return local ToolSpecs for bash/read/write/edit (no sandbox_create).""" from agent.core.tools import ToolSpec tools = [] for name, spec in _LOCAL_TOOL_SPECS.items(): handler = _HANDLERS.get(name) if handler is None: continue tools.append( ToolSpec( name=name, description=spec["description"], parameters=spec["parameters"], handler=handler, ) ) return tools ================================================ FILE: agent/tools/papers_tool.py ================================================ """ HF Papers Tool — Discover papers, read their contents, and find linked resources. Operations: trending, search, paper_details, read_paper, find_datasets, find_models, find_collections, find_all_resources, citation_graph, snippet_search, recommend """ import asyncio import os import re import time from typing import Any import httpx from bs4 import BeautifulSoup, Tag from agent.tools.types import ToolResult HF_API = "https://huggingface.co/api" ARXIV_HTML = "https://arxiv.org/html" AR5IV_HTML = "https://ar5iv.labs.arxiv.org/html" DEFAULT_LIMIT = 10 MAX_LIMIT = 50 MAX_SUMMARY_LEN = 300 MAX_SECTION_PREVIEW_LEN = 280 MAX_SECTION_TEXT_LEN = 8000 SORT_MAP = { "downloads": "downloads", "likes": "likes", "trending": "trendingScore", } # --------------------------------------------------------------------------- # Semantic Scholar API # --------------------------------------------------------------------------- S2_API = "https://api.semanticscholar.org" S2_API_KEY = os.environ.get("S2_API_KEY") S2_HEADERS: dict[str, str] = {"x-api-key": S2_API_KEY} if S2_API_KEY else {} S2_TIMEOUT = 12 _s2_last_request: float = 0.0 # Shared response cache (survives across sessions, keyed by (path, params_tuple)) _s2_cache: dict[str, Any] = {} _S2_CACHE_MAX = 500 def _s2_paper_id(arxiv_id: str) -> str: """Convert bare arxiv ID to S2 format.""" return f"ARXIV:{arxiv_id}" def _s2_cache_key(path: str, params: dict | None) -> str: """Build a hashable cache key from path + sorted params.""" p = tuple(sorted((params or {}).items())) return f"{path}:{p}" async def _s2_request( client: httpx.AsyncClient, method: str, path: str, **kwargs: Any, ) -> httpx.Response | None: """S2 request with 2 retries on 429/5xx. 
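    Backs off 60s on 429 and 3s on 5xx or transport errors before retrying.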
Rate-limited only when using API key.""" global _s2_last_request url = f"{S2_API}{path}" kwargs.setdefault("headers", {}).update(S2_HEADERS) kwargs.setdefault("timeout", S2_TIMEOUT) for attempt in range(3): # Rate limit only when authenticated (1 req/s for search, 10 req/s for others) if S2_API_KEY: min_interval = 1.0 if "search" in path else 0.1 elapsed = time.monotonic() - _s2_last_request if elapsed < min_interval: await asyncio.sleep(min_interval - elapsed) _s2_last_request = time.monotonic() try: resp = await client.request(method, url, **kwargs) if resp.status_code == 429: if attempt < 2: await asyncio.sleep(60) continue return None if resp.status_code >= 500: if attempt < 2: await asyncio.sleep(3) continue return None return resp except (httpx.RequestError, httpx.HTTPStatusError): if attempt < 2: await asyncio.sleep(3) continue return None return None async def _s2_get_json( client: httpx.AsyncClient, path: str, params: dict | None = None, ) -> dict | None: """Cached S2 GET returning parsed JSON or None.""" key = _s2_cache_key(path, params) if key in _s2_cache: return _s2_cache[key] resp = await _s2_request(client, "GET", path, params=params or {}) if resp and resp.status_code == 200: data = resp.json() if len(_s2_cache) < _S2_CACHE_MAX: _s2_cache[key] = data return data return None async def _s2_get_paper( client: httpx.AsyncClient, arxiv_id: str, fields: str, ) -> dict | None: """Fetch a single paper from S2 by arxiv ID. Returns None on failure.""" return await _s2_get_json( client, f"/graph/v1/paper/{_s2_paper_id(arxiv_id)}", {"fields": fields}, ) # --------------------------------------------------------------------------- # HTML paper parsing # --------------------------------------------------------------------------- def _parse_paper_html(html: str) -> dict[str, Any]: """Parse arxiv HTML into structured sections. 
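    Assumes the LaTeXML markup (ltx_title, ltx_abstract, and related ltx_*
    classes) that both arxiv.org/html and ar5iv builds emit.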
Returns: { "title": str, "abstract": str, "sections": [{"id": str, "title": str, "level": int, "text": str}], } """ soup = BeautifulSoup(html, "html.parser") # Title title_el = soup.find("h1", class_="ltx_title") title = title_el.get_text(strip=True).removeprefix("Title:") if title_el else "" # Abstract abstract_el = soup.find("div", class_="ltx_abstract") abstract = "" if abstract_el: # Skip the "Abstract" heading itself for child in abstract_el.children: if isinstance(child, Tag) and child.name in ("h6", "h2", "h3", "p", "span"): if child.get_text(strip=True).lower() == "abstract": continue if isinstance(child, Tag) and child.name == "p": abstract += child.get_text(separator=" ", strip=True) + " " abstract = abstract.strip() # Sections — collect h2/h3 headings and text between them sections: list[dict[str, Any]] = [] headings = soup.find_all(["h2", "h3"], class_=lambda c: c and "ltx_title" in c) for heading in headings: level = 2 if heading.name == "h2" else 3 heading_text = heading.get_text(separator=" ", strip=True) # Collect text from siblings until next heading of same or higher level text_parts: list[str] = [] sibling = heading.find_next_sibling() while sibling: if isinstance(sibling, Tag): if sibling.name in ("h2", "h3") and "ltx_title" in ( sibling.get("class") or [] ): break # Also stop at h2 if we're collecting h3 content if sibling.name == "h2" and level == 3: break text_parts.append(sibling.get_text(separator=" ", strip=True)) sibling = sibling.find_next_sibling() # Also check parent section element for contained paragraphs parent_section = heading.find_parent("section") if parent_section and not text_parts: for p in parent_section.find_all("p", recursive=False): text_parts.append(p.get_text(separator=" ", strip=True)) section_text = "\n\n".join(t for t in text_parts if t) # Extract section number from heading text (e.g., "4 Experiments" → "4") num_match = re.match(r"^([A-Z]?\d+(?:\.\d+)*)\s", heading_text) section_id = num_match.group(1) if num_match else "" sections.append( { "id": section_id, "title": heading_text, "level": level, "text": section_text, } ) return {"title": title, "abstract": abstract, "sections": sections} def _find_section(sections: list[dict], query: str) -> dict | None: """Find a section by number or name (fuzzy).""" query_lower = query.lower().strip() # Exact match on section number for s in sections: if s["id"] == query_lower or s["id"] == query: return s # Exact match on title for s in sections: if query_lower == s["title"].lower(): return s # Substring match on title for s in sections: if query_lower in s["title"].lower(): return s # Number prefix match (e.g., "4" matches "4.1", "4.2", etc. — return parent) for s in sections: if s["id"].startswith(query_lower + ".") or s["id"] == query_lower: return s return None # --------------------------------------------------------------------------- # Formatting helpers # --------------------------------------------------------------------------- def _clean_description(text: str) -> str: """Strip HTML card artifacts and collapse whitespace from HF API descriptions.""" text = re.sub(r"[\t]+", " ", text) text = re.sub(r"\n{2,}", "\n", text) return text.strip() def _truncate(text: str, max_len: int) -> str: if len(text) <= max_len: return text return text[:max_len] + "..." 
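

# Illustrative sketch (not part of the original module): how _find_section's
# lookup cascade resolves a query: exact section id, exact title, substring
# on title, then numeric-prefix fallback. The section dicts below are
# hypothetical examples in the shape produced by _parse_paper_html.
def _find_section_demo() -> None:  # pragma: no cover
    sections = [
        {"id": "4", "title": "4 Experiments", "level": 2, "text": "..."},
        {"id": "4.1", "title": "4.1 Setup", "level": 3, "text": "..."},
    ]
    assert _find_section(sections, "4")["title"] == "4 Experiments"  # exact id
    assert _find_section(sections, "Setup")["id"] == "4.1"           # substring on title
    assert _find_section(sections, "experiments")["id"] == "4"       # substring on title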
def _format_paper_list( papers: list, title: str, date: str | None = None, query: str | None = None ) -> str: lines = [f"# {title}"] if date: lines[0] += f" ({date})" if query: lines.append(f"Filtered by: '{query}'") lines.append(f"Showing {len(papers)} paper(s)\n") for i, item in enumerate(papers, 1): paper = item.get("paper", item) arxiv_id = paper.get("id", "") paper_title = paper.get("title", "Unknown") upvotes = paper.get("upvotes", 0) summary = paper.get("ai_summary") or _truncate( paper.get("summary", ""), MAX_SUMMARY_LEN ) keywords = paper.get("ai_keywords") or [] github = paper.get("githubRepo") or "" stars = paper.get("githubStars") or 0 lines.append(f"## {i}. {paper_title}") lines.append(f"**arxiv_id:** {arxiv_id} | **upvotes:** {upvotes}") lines.append(f"https://huggingface.co/papers/{arxiv_id}") if keywords: lines.append(f"**Keywords:** {', '.join(keywords[:5])}") if github: lines.append(f"**GitHub:** {github} ({stars} stars)") if summary: lines.append(f"**Summary:** {_truncate(summary, MAX_SUMMARY_LEN)}") lines.append("") return "\n".join(lines) def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str: arxiv_id = paper.get("id", "") title = paper.get("title", "Unknown") upvotes = paper.get("upvotes", 0) ai_summary = paper.get("ai_summary") or "" summary = paper.get("summary", "") keywords = paper.get("ai_keywords") or [] github = paper.get("githubRepo") or "" stars = paper.get("githubStars") or 0 authors = paper.get("authors") or [] lines = [f"# {title}"] meta_parts = [f"**arxiv_id:** {arxiv_id}", f"**upvotes:** {upvotes}"] if s2_data: cites = s2_data.get("citationCount", 0) influential = s2_data.get("influentialCitationCount", 0) meta_parts.append(f"**citations:** {cites} ({influential} influential)") lines.append(" | ".join(meta_parts)) lines.append(f"https://huggingface.co/papers/{arxiv_id}") lines.append(f"https://arxiv.org/abs/{arxiv_id}") if authors: names = [a.get("name", "") for a in authors[:10]] author_str = ", ".join(n for n in names if n) if len(authors) > 10: author_str += f" (+{len(authors) - 10} more)" lines.append(f"**Authors:** {author_str}") if keywords: lines.append(f"**Keywords:** {', '.join(keywords)}") if s2_data and s2_data.get("s2FieldsOfStudy"): fields = [f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category")] if fields: lines.append(f"**Fields:** {', '.join(fields)}") if s2_data and s2_data.get("venue"): lines.append(f"**Venue:** {s2_data['venue']}") if github: lines.append(f"**GitHub:** {github} ({stars} stars)") if s2_data and s2_data.get("tldr"): tldr_text = s2_data["tldr"].get("text", "") if tldr_text: lines.append(f"\n## TL;DR\n{tldr_text}") if ai_summary: lines.append(f"\n## AI Summary\n{ai_summary}") if summary: lines.append(f"\n## Abstract\n{_truncate(summary, 500)}") lines.append( "\n**Next:** Use read_paper to read specific sections, find_all_resources for linked datasets/models, " "or citation_graph to trace references and citations." 
) return "\n".join(lines) def _format_read_paper_toc(parsed: dict[str, Any], arxiv_id: str) -> str: """Format TOC view: abstract + section list with previews.""" lines = [f"# {parsed['title']}"] lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") if parsed["abstract"]: lines.append(f"## Abstract\n{parsed['abstract']}\n") lines.append("## Sections") for s in parsed["sections"]: prefix = " " if s["level"] == 3 else "" preview = ( _truncate(s["text"], MAX_SECTION_PREVIEW_LEN) if s["text"] else "(empty)" ) lines.append(f"{prefix}- **{s['title']}**: {preview}") lines.append( '\nCall read_paper with section parameter (e.g. section="4" or section="Experiments") to read a specific section.' ) return "\n".join(lines) def _format_read_paper_section(section: dict, arxiv_id: str) -> str: """Format a single section's full text.""" lines = [f"# {section['title']}"] lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") text = section["text"] if len(text) > MAX_SECTION_TEXT_LEN: text = ( text[:MAX_SECTION_TEXT_LEN] + f"\n\n... (truncated at {MAX_SECTION_TEXT_LEN} chars)" ) lines.append(text if text else "(This section has no extractable text content.)") return "\n".join(lines) def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str: lines = [f"# Datasets linked to paper {arxiv_id}"] lines.append(f"https://huggingface.co/papers/{arxiv_id}") lines.append(f"Showing {len(datasets)} dataset(s), sorted by {sort}\n") for i, ds in enumerate(datasets, 1): ds_id = ds.get("id", "unknown") downloads = ds.get("downloads", 0) likes = ds.get("likes", 0) desc = _truncate(_clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN) tags = ds.get("tags") or [] interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5] lines.append(f"**{i}. [{ds_id}](https://huggingface.co/datasets/{ds_id})**") lines.append(f" Downloads: {downloads:,} | Likes: {likes}") if interesting: lines.append(f" Tags: {', '.join(interesting)}") if desc: lines.append(f" {desc}") lines.append("") if datasets: top = datasets[0].get("id", "") lines.append(f'**Inspect top dataset:** hf_inspect_dataset(dataset="{top}")') return "\n".join(lines) def _format_datasets_compact(datasets: list) -> str: if not datasets: return "## Datasets\nNone found" lines = [f"## Datasets ({len(datasets)})"] for ds in datasets: lines.append( f"- **{ds.get('id', '?')}** ({ds.get('downloads', 0):,} downloads)" ) return "\n".join(lines) def _format_models(models: list, arxiv_id: str, sort: str) -> str: lines = [f"# Models linked to paper {arxiv_id}"] lines.append(f"https://huggingface.co/papers/{arxiv_id}") lines.append(f"Showing {len(models)} model(s), sorted by {sort}\n") for i, m in enumerate(models, 1): model_id = m.get("id", "unknown") downloads = m.get("downloads", 0) likes = m.get("likes", 0) pipeline = m.get("pipeline_tag") or "" library = m.get("library_name") or "" lines.append(f"**{i}. 
[{model_id}](https://huggingface.co/{model_id})**") meta = f" Downloads: {downloads:,} | Likes: {likes}" if pipeline: meta += f" | Task: {pipeline}" if library: meta += f" | Library: {library}" lines.append(meta) lines.append("") return "\n".join(lines) def _format_models_compact(models: list) -> str: if not models: return "## Models\nNone found" lines = [f"## Models ({len(models)})"] for m in models: pipeline = m.get("pipeline_tag") or "" suffix = f" ({pipeline})" if pipeline else "" lines.append( f"- **{m.get('id', '?')}** ({m.get('downloads', 0):,} downloads){suffix}" ) return "\n".join(lines) def _format_collections(collections: list, arxiv_id: str) -> str: lines = [f"# Collections containing paper {arxiv_id}"] lines.append(f"Showing {len(collections)} collection(s)\n") for i, c in enumerate(collections, 1): slug = c.get("slug", "") title = c.get("title", "Untitled") upvotes = c.get("upvotes", 0) owner = c.get("owner", {}).get("name", "") desc = _truncate(c.get("description") or "", MAX_SUMMARY_LEN) num_items = len(c.get("items", [])) lines.append(f"**{i}. {title}**") lines.append(f" By: {owner} | Upvotes: {upvotes} | Items: {num_items}") lines.append(f" https://huggingface.co/collections/{slug}") if desc: lines.append(f" {desc}") lines.append("") return "\n".join(lines) def _format_collections_compact(collections: list) -> str: if not collections: return "## Collections\nNone found" lines = [f"## Collections ({len(collections)})"] for c in collections: title = c.get("title", "Untitled") owner = c.get("owner", {}).get("name", "") upvotes = c.get("upvotes", 0) lines.append(f"- **{title}** by {owner} ({upvotes} upvotes)") return "\n".join(lines) # --------------------------------------------------------------------------- # Operation handlers # --------------------------------------------------------------------------- def _error(message: str) -> ToolResult: return { "formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True, } def _validate_arxiv_id(args: dict) -> str | None: """Return arxiv_id or None if missing.""" return args.get("arxiv_id") async def _op_trending(args: dict[str, Any], limit: int) -> ToolResult: date = args.get("date") query = args.get("query") params: dict[str, Any] = {"limit": limit if not query else max(limit * 3, 30)} if date: params["date"] = date async with httpx.AsyncClient(timeout=15) as client: resp = await client.get(f"{HF_API}/daily_papers", params=params) resp.raise_for_status() papers = resp.json() if query: q = query.lower() papers = [ p for p in papers if q in p.get("title", "").lower() or q in p.get("paper", {}).get("title", "").lower() or q in p.get("paper", {}).get("summary", "").lower() or any( q in kw.lower() for kw in (p.get("paper", {}).get("ai_keywords") or []) ) ] papers = papers[:limit] if not papers: msg = "No trending papers found" if query: msg += f" matching '{query}'" if date: msg += f" for {date}" return {"formatted": msg, "totalResults": 0, "resultsShared": 0} formatted = _format_paper_list(papers, "Trending Papers", date=date, query=query) return { "formatted": formatted, "totalResults": len(papers), "resultsShared": len(papers), } def _format_s2_paper_list(papers: list[dict], title: str) -> str: """Format a list of S2 paper results.""" lines = [f"# {title}"] lines.append(f"Showing {len(papers)} result(s)\n") for i, paper in enumerate(papers, 1): ptitle = paper.get("title") or "(untitled)" year = paper.get("year") or "?" 
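        # externalIds.ArXiv is the bridge back to the HF/arxiv operations;
        # S2 results without an arXiv id simply render without a link.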
cites = paper.get("citationCount", 0) venue = paper.get("venue") or "" ext_ids = paper.get("externalIds") or {} aid = ext_ids.get("ArXiv", "") tldr = (paper.get("tldr") or {}).get("text", "") lines.append(f"### {i}. {ptitle}") meta = [f"Year: {year}", f"Citations: {cites}"] if venue: meta.append(f"Venue: {venue}") if aid: meta.append(f"arxiv_id: {aid}") lines.append(" | ".join(meta)) if aid: lines.append(f"https://arxiv.org/abs/{aid}") if tldr: lines.append(f"**TL;DR:** {tldr}") lines.append("") lines.append("Use paper_details with arxiv_id for full info, or read_paper to read sections.") return "\n".join(lines) async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolResult | None: """Search via S2 bulk endpoint with filters. Returns None on failure.""" params: dict[str, Any] = { "query": query, "limit": limit, "fields": "title,externalIds,year,citationCount,tldr,venue,publicationDate", } # Date filter date_from = args.get("date_from", "") date_to = args.get("date_to", "") if date_from or date_to: params["publicationDateOrYear"] = f"{date_from}:{date_to}" # Fields of study categories = args.get("categories") if categories: params["fieldsOfStudy"] = categories # Min citations min_cites = args.get("min_citations") if min_cites: params["minCitationCount"] = str(min_cites) # Sort sort_by = args.get("sort_by") if sort_by and sort_by != "relevance": params["sort"] = f"{sort_by}:desc" async with httpx.AsyncClient(timeout=15) as client: resp = await _s2_request(client, "GET", "/graph/v1/paper/search/bulk", params=params) if not resp or resp.status_code != 200: return None data = resp.json() papers = data.get("data") or [] if not papers: return { "formatted": f"No papers found for '{query}' with the given filters.", "totalResults": 0, "resultsShared": 0, } formatted = _format_s2_paper_list(papers[:limit], f"Papers matching '{query}' (Semantic Scholar)") return { "formatted": formatted, "totalResults": data.get("total", len(papers)), "resultsShared": min(limit, len(papers)), } async def _op_search(args: dict[str, Any], limit: int) -> ToolResult: query = args.get("query") if not query: return _error("'query' is required for search operation.") # Route to S2 when filters are present use_s2 = any(args.get(k) for k in ("date_from", "date_to", "categories", "min_citations", "sort_by")) if use_s2: result = await _s2_bulk_search(query, args, limit) if result is not None: return result # Fall back to HF search (without filters) if S2 fails async with httpx.AsyncClient(timeout=15) as client: resp = await client.get( f"{HF_API}/papers/search", params={"q": query, "limit": limit} ) resp.raise_for_status() papers = resp.json() if not papers: return { "formatted": f"No papers found for '{query}'", "totalResults": 0, "resultsShared": 0, } formatted = _format_paper_list(papers, f"Papers matching '{query}'") return { "formatted": formatted, "totalResults": len(papers), "resultsShared": len(papers), } async def _op_paper_details(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for paper_details.") async with httpx.AsyncClient(timeout=15) as client: resp = await client.get(f"{HF_API}/papers/{arxiv_id}") resp.raise_for_status() paper = resp.json() return { "formatted": _format_paper_detail(paper), "totalResults": 1, "resultsShared": 1, } async def _op_read_paper(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for 
read_paper.") section_query = args.get("section") # Try fetching HTML from arxiv, then ar5iv, then fallback to abstract parsed = None async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client: for base_url in [ARXIV_HTML, AR5IV_HTML]: try: resp = await client.get(f"{base_url}/{arxiv_id}") if resp.status_code == 200: parsed = _parse_paper_html(resp.text) if parsed["sections"]: # Only use if we got real sections break parsed = None except httpx.RequestError: continue # Fallback: return abstract from HF API if not parsed or not parsed["sections"]: try: async with httpx.AsyncClient(timeout=15) as client: resp = await client.get(f"{HF_API}/papers/{arxiv_id}") resp.raise_for_status() paper = resp.json() abstract = paper.get("summary", "") title = paper.get("title", "") msg = f"# {title}\nhttps://arxiv.org/abs/{arxiv_id}\n\n" msg += f"## Abstract\n{abstract}\n\n" msg += "HTML version not available for this paper. Only abstract shown.\n" msg += f"PDF: https://arxiv.org/pdf/{arxiv_id}" return {"formatted": msg, "totalResults": 1, "resultsShared": 1} except Exception: return _error( f"Could not fetch paper {arxiv_id}. Check the arxiv ID is correct." ) # Return TOC or specific section if not section_query: formatted = _format_read_paper_toc(parsed, arxiv_id) return { "formatted": formatted, "totalResults": len(parsed["sections"]), "resultsShared": len(parsed["sections"]), } section = _find_section(parsed["sections"], section_query) if not section: available = "\n".join(f"- {s['title']}" for s in parsed["sections"]) return _error( f"Section '{section_query}' not found. Available sections:\n{available}" ) formatted = _format_read_paper_section(section, arxiv_id) return {"formatted": formatted, "totalResults": 1, "resultsShared": 1} # --------------------------------------------------------------------------- # Citation graph (Semantic Scholar) # --------------------------------------------------------------------------- def _format_citation_entry(entry: dict, show_context: bool = False) -> str: """Format a single citation/reference entry.""" paper = entry.get("citingPaper") or entry.get("citedPaper") or {} title = paper.get("title") or "(untitled)" year = paper.get("year") or "?" 
cites = paper.get("citationCount", 0) ext_ids = paper.get("externalIds") or {} aid = ext_ids.get("ArXiv", "") influential = " **[influential]**" if entry.get("isInfluential") else "" parts = [f"- **{title}** ({year}, {cites} cites){influential}"] if aid: parts[0] += f" arxiv:{aid}" if show_context: intents = entry.get("intents") or [] if intents: parts.append(f" Intent: {', '.join(intents)}") contexts = entry.get("contexts") or [] for ctx in contexts[:2]: if ctx: parts.append(f" > {_truncate(ctx, 200)}") return "\n".join(parts) def _format_citation_graph( arxiv_id: str, references: list[dict] | None, citations: list[dict] | None, ) -> str: lines = [f"# Citation Graph for {arxiv_id}"] lines.append(f"https://arxiv.org/abs/{arxiv_id}\n") if references is not None: lines.append(f"## References ({len(references)})") if references: for entry in references: lines.append(_format_citation_entry(entry)) else: lines.append("No references found.") lines.append("") if citations is not None: lines.append(f"## Citations ({len(citations)})") if citations: for entry in citations: lines.append(_format_citation_entry(entry, show_context=True)) else: lines.append("No citations found.") lines.append("") lines.append("**Tip:** Use paper_details with an arxiv_id from above to explore further.") return "\n".join(lines) async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for citation_graph.") direction = args.get("direction", "both") s2_id = _s2_paper_id(arxiv_id) fields = "title,externalIds,year,citationCount,influentialCitationCount,contexts,intents,isInfluential" params = {"fields": fields, "limit": limit} async with httpx.AsyncClient(timeout=15) as client: refs, cites = None, None coros = [] if direction in ("references", "both"): coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params)) if direction in ("citations", "both"): coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params)) results = await asyncio.gather(*coros, return_exceptions=True) idx = 0 if direction in ("references", "both"): r = results[idx] if isinstance(r, dict): refs = r.get("data", []) idx += 1 if direction in ("citations", "both"): r = results[idx] if isinstance(r, dict): cites = r.get("data", []) if refs is None and cites is None: return _error(f"Could not fetch citation data for {arxiv_id}. 
Paper may not be indexed by Semantic Scholar.") total = (len(refs) if refs else 0) + (len(cites) if cites else 0) return { "formatted": _format_citation_graph(arxiv_id, refs, cites), "totalResults": total, "resultsShared": total, } async def _op_find_datasets(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for find_datasets.") sort = args.get("sort", "downloads") sort_key = SORT_MAP.get(sort, "downloads") async with httpx.AsyncClient(timeout=15) as client: resp = await client.get( f"{HF_API}/datasets", params={ "filter": f"arxiv:{arxiv_id}", "limit": limit, "sort": sort_key, "direction": -1, }, ) resp.raise_for_status() datasets = resp.json() if not datasets: return { "formatted": f"No datasets found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", "totalResults": 0, "resultsShared": 0, } return { "formatted": _format_datasets(datasets, arxiv_id, sort), "totalResults": len(datasets), "resultsShared": len(datasets), } async def _op_find_models(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for find_models.") sort = args.get("sort", "downloads") sort_key = SORT_MAP.get(sort, "downloads") async with httpx.AsyncClient(timeout=15) as client: resp = await client.get( f"{HF_API}/models", params={ "filter": f"arxiv:{arxiv_id}", "limit": limit, "sort": sort_key, "direction": -1, }, ) resp.raise_for_status() models = resp.json() if not models: return { "formatted": f"No models found linked to paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", "totalResults": 0, "resultsShared": 0, } return { "formatted": _format_models(models, arxiv_id, sort), "totalResults": len(models), "resultsShared": len(models), } async def _op_find_collections(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for find_collections.") async with httpx.AsyncClient(timeout=15) as client: resp = await client.get(f"{HF_API}/collections", params={"paper": arxiv_id}) resp.raise_for_status() collections = resp.json() if not collections: return { "formatted": f"No collections found containing paper {arxiv_id}.\nhttps://huggingface.co/papers/{arxiv_id}", "totalResults": 0, "resultsShared": 0, } collections = collections[:limit] return { "formatted": _format_collections(collections, arxiv_id), "totalResults": len(collections), "resultsShared": len(collections), } async def _op_find_all_resources(args: dict[str, Any], limit: int) -> ToolResult: arxiv_id = _validate_arxiv_id(args) if not arxiv_id: return _error("'arxiv_id' is required for find_all_resources.") per_cat = min(limit, 10) async with httpx.AsyncClient(timeout=15) as client: results = await asyncio.gather( client.get( f"{HF_API}/datasets", params={ "filter": f"arxiv:{arxiv_id}", "limit": per_cat, "sort": "downloads", "direction": -1, }, ), client.get( f"{HF_API}/models", params={ "filter": f"arxiv:{arxiv_id}", "limit": per_cat, "sort": "downloads", "direction": -1, }, ), client.get(f"{HF_API}/collections", params={"paper": arxiv_id}), return_exceptions=True, ) sections = [] total = 0 # Datasets if isinstance(results[0], Exception): sections.append(f"## Datasets\nError: {results[0]}") else: datasets = results[0].json() total += len(datasets) sections.append(_format_datasets_compact(datasets[:per_cat])) # Models if isinstance(results[1], Exception): sections.append(f"## 
Models\nError: {results[1]}") else: models = results[1].json() total += len(models) sections.append(_format_models_compact(models[:per_cat])) # Collections if isinstance(results[2], Exception): sections.append(f"## Collections\nError: {results[2]}") else: collections = results[2].json() total += len(collections) sections.append(_format_collections_compact(collections[:per_cat])) header = f"# Resources linked to paper {arxiv_id}\nhttps://huggingface.co/papers/{arxiv_id}\n" formatted = header + "\n\n".join(sections) return {"formatted": formatted, "totalResults": total, "resultsShared": total} # --------------------------------------------------------------------------- # Snippet search (Semantic Scholar) # --------------------------------------------------------------------------- def _format_snippets(snippets: list[dict], query: str) -> str: lines = [f"# Snippet Search: '{query}'"] lines.append(f"Found {len(snippets)} matching passage(s)\n") for i, item in enumerate(snippets, 1): paper = item.get("paper") or {} ptitle = paper.get("title") or "(untitled)" year = paper.get("year") or "?" cites = paper.get("citationCount", 0) ext_ids = paper.get("externalIds") or {} aid = ext_ids.get("ArXiv", "") snippet = item.get("snippet") or {} text = snippet.get("text", "") section = snippet.get("section") or "" lines.append(f"### {i}. {ptitle} ({year}, {cites} cites)") if aid: lines.append(f"arxiv:{aid}") if section: lines.append(f"Section: {section}") if text: lines.append(f"> {_truncate(text, 400)}") lines.append("") lines.append("Use paper_details or read_paper with arxiv_id to explore a paper further.") return "\n".join(lines) async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult: query = args.get("query") if not query: return _error("'query' is required for snippet_search.") params: dict[str, Any] = { "query": query, "limit": limit, "fields": "title,externalIds,year,citationCount", } # Optional filters (same as search) date_from = args.get("date_from", "") date_to = args.get("date_to", "") if date_from or date_to: params["publicationDateOrYear"] = f"{date_from}:{date_to}" if args.get("categories"): params["fieldsOfStudy"] = args["categories"] if args.get("min_citations"): params["minCitationCount"] = str(args["min_citations"]) async with httpx.AsyncClient(timeout=15) as client: resp = await _s2_request(client, "GET", "/graph/v1/snippet/search", params=params) if not resp or resp.status_code != 200: return _error("Snippet search failed. 
Semantic Scholar may be unavailable.") data = resp.json() snippets = data.get("data") or [] if not snippets: return { "formatted": f"No snippets found for '{query}'.", "totalResults": 0, "resultsShared": 0, } return { "formatted": _format_snippets(snippets, query), "totalResults": len(snippets), "resultsShared": len(snippets), } # --------------------------------------------------------------------------- # Recommendations (Semantic Scholar) # --------------------------------------------------------------------------- async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult: positive_ids = args.get("positive_ids") arxiv_id = _validate_arxiv_id(args) if not arxiv_id and not positive_ids: return _error("'arxiv_id' or 'positive_ids' is required for recommend.") fields = "title,externalIds,year,citationCount,tldr,venue" async with httpx.AsyncClient(timeout=15) as client: if positive_ids and not arxiv_id: # Multi-paper recommendations (POST, not cached) pos = [_s2_paper_id(pid.strip()) for pid in positive_ids.split(",") if pid.strip()] neg_raw = args.get("negative_ids", "") neg = [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] if neg_raw else [] resp = await _s2_request( client, "POST", "/recommendations/v1/papers/", json={"positivePaperIds": pos, "negativePaperIds": neg}, params={"fields": fields, "limit": limit}, ) if not resp or resp.status_code != 200: return _error("Recommendation request failed. Semantic Scholar may be unavailable.") data = resp.json() else: # Single-paper recommendations (cached) data = await _s2_get_json( client, f"/recommendations/v1/papers/forpaper/{_s2_paper_id(arxiv_id)}", {"fields": fields, "limit": limit, "from": "recent"}, ) if not data: return _error("Recommendation request failed. Semantic Scholar may be unavailable.") papers = data.get("recommendedPapers") or [] if not papers: return { "formatted": "No recommendations found.", "totalResults": 0, "resultsShared": 0, } title = f"Recommended papers based on {arxiv_id or positive_ids}" return { "formatted": _format_s2_paper_list(papers[:limit], title), "totalResults": len(papers), "resultsShared": min(limit, len(papers)), } # --------------------------------------------------------------------------- # Operation dispatch # --------------------------------------------------------------------------- _OPERATIONS = { "trending": _op_trending, "search": _op_search, "paper_details": _op_paper_details, "read_paper": _op_read_paper, "citation_graph": _op_citation_graph, "snippet_search": _op_snippet_search, "recommend": _op_recommend, "find_datasets": _op_find_datasets, "find_models": _op_find_models, "find_collections": _op_find_collections, "find_all_resources": _op_find_all_resources, } # --------------------------------------------------------------------------- # Tool spec + handler # --------------------------------------------------------------------------- HF_PAPERS_TOOL_SPEC = { "name": "hf_papers", "description": ( "Discover ML research papers, analyze citations, search paper contents, and find linked resources.\n\n" "Combines HuggingFace Hub, arXiv, and Semantic Scholar. 
Use for exploring research areas, " "finding datasets for a task, tracing citation chains, or implementing a paper's approach.\n\n" "Typical flows:\n" " search → read_paper → find_all_resources → hf_inspect_dataset\n" " search → paper_details → citation_graph → read_paper (trace influence)\n" " snippet_search → paper_details → read_paper (find specific claims)\n\n" "Operations:\n" "- trending: Get trending daily papers, optionally filter by topic keyword\n" "- search: Search papers. Uses HF by default (ML-tuned). Add date_from/min_citations/categories to use Semantic Scholar with filters\n" "- paper_details: Metadata, abstract, AI summary, github link\n" "- read_paper: Read paper contents — without section: abstract + TOC; with section: full text\n" "- citation_graph: Get references and citations for a paper with influence flags and citation intents\n" "- snippet_search: Semantic search over full-text passages from 12M+ papers\n" "- recommend: Find similar papers (single paper or positive/negative examples)\n" "- find_datasets: Find datasets linked to a paper\n" "- find_models: Find models linked to a paper\n" "- find_collections: Find collections that include a paper\n" "- find_all_resources: Parallel fetch of datasets + models + collections for a paper" ), "parameters": { "type": "object", "properties": { "operation": { "type": "string", "enum": list(_OPERATIONS.keys()), "description": "Operation to execute.", }, "query": { "type": "string", "description": ( "Search query. Required for: search, snippet_search. " "Optional for: trending (filters by keyword). " "Supports boolean syntax for Semantic Scholar: '\"exact phrase\" term1 | term2'." ), }, "arxiv_id": { "type": "string", "description": ( "ArXiv paper ID (e.g. '2305.18290'). " "Required for: paper_details, read_paper, citation_graph, find_datasets, find_models, find_collections, find_all_resources. " "Optional for: recommend (single-paper recs). Get IDs from search results first." ), }, "section": { "type": "string", "description": ( "Section name or number to read (e.g. '3', 'Experiments', '4.2'). " "Optional for: read_paper. Without this, returns abstract + TOC." ), }, "direction": { "type": "string", "enum": ["citations", "references", "both"], "description": "Direction for citation_graph. Default: both.", }, "date": { "type": "string", "description": "Date in YYYY-MM-DD format. Optional for: trending (defaults to recent papers).", }, "date_from": { "type": "string", "description": "Start date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", }, "date_to": { "type": "string", "description": "End date (YYYY-MM-DD). Triggers Semantic Scholar search. For: search, snippet_search.", }, "categories": { "type": "string", "description": "Field of study filter (e.g. 'Computer Science'). Triggers Semantic Scholar search.", }, "min_citations": { "type": "integer", "description": "Minimum citation count filter. Triggers Semantic Scholar search.", }, "sort_by": { "type": "string", "enum": ["relevance", "citationCount", "publicationDate"], "description": "Sort order for Semantic Scholar search. Default: relevance.", }, "positive_ids": { "type": "string", "description": "Comma-separated arxiv IDs for multi-paper recommendations. For: recommend.", }, "negative_ids": { "type": "string", "description": "Comma-separated arxiv IDs as negative examples. For: recommend.", }, "sort": { "type": "string", "enum": ["downloads", "likes", "trending"], "description": ( "Sort order for find_datasets and find_models. 
Default: downloads." ), }, "limit": { "type": "integer", "description": "Maximum results to return (default: 10, max: 50).", }, }, "required": ["operation"], }, } async def hf_papers_handler(arguments: dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router.""" operation = arguments.get("operation") if not operation: return "'operation' parameter is required.", False handler = _OPERATIONS.get(operation) if not handler: valid = ", ".join(_OPERATIONS.keys()) return f"Unknown operation: '{operation}'. Valid: {valid}", False limit = min(arguments.get("limit", DEFAULT_LIMIT), MAX_LIMIT) try: result = await handler(arguments, limit) return result["formatted"], not result.get("isError", False) except httpx.HTTPStatusError as e: return f"API error: {e.response.status_code} — {e.response.text[:200]}", False except httpx.RequestError as e: return f"Request error: {e}", False except Exception as e: return f"Error in {operation}: {e}", False ================================================ FILE: agent/tools/plan_tool.py ================================================ from typing import Any, Dict, List from agent.core.session import Event from agent.utils.terminal_display import format_plan_tool_output from .types import ToolResult # In-memory storage for the current plan (raw structure from agent) _current_plan: List[Dict[str, str]] = [] class PlanTool: """Tool for managing a list of todos with status tracking.""" def __init__(self, session: Any = None): self.session = session async def execute(self, params: Dict[str, Any]) -> ToolResult: """ Execute the WritePlan operation. Args: params: Dictionary containing: - todos: List of todo items, each with id, content, and status Returns: ToolResult with formatted output """ global _current_plan todos = params.get("todos", []) # Validate todos structure for todo in todos: if not isinstance(todo, dict): return { "formatted": "Error: Each todo must be an object. Re call the tool with correct format (mandatory).", "isError": True, } required_fields = ["id", "content", "status"] for field in required_fields: if field not in todo: return { "formatted": f"Error: Todo missing required field '{field}'. Re call the tool with correct format (mandatory).", "isError": True, } # Validate status valid_statuses = ["pending", "in_progress", "completed"] if todo["status"] not in valid_statuses: return { "formatted": f"Error: Invalid status '{todo['status']}'. Must be one of: {', '.join(valid_statuses)}. Re call the tool with correct format (mandatory).", "isError": True, } # Store the raw todos structure in memory _current_plan = todos # Emit plan update event if session is available if self.session: await self.session.send_event( Event( event_type="plan_update", data={"plan": todos}, ) ) # Format only for display using terminal_display utility formatted_output = format_plan_tool_output(todos) return { "formatted": formatted_output, "totalResults": len(todos), "isError": False, } def get_current_plan() -> List[Dict[str, str]]: """Get the current plan (raw structure).""" return _current_plan # Tool specification PLAN_TOOL_SPEC = { "name": "plan_tool", "description": ( "Track progress on multi-step tasks with a todo list (pending/in_progress/completed).\n\n" "Use for tasks with 3+ steps. Each call replaces the entire plan (send full list).\n\n" "Rules: exactly ONE task in_progress at a time. Mark completed immediately after finishing. " "Only mark completed when the task fully succeeded — keep in_progress if there are errors. 
" "Update frequently so the user sees progress." ), "parameters": { "type": "object", "properties": { "todos": { "type": "array", "description": "List of todo items", "items": { "type": "object", "properties": { "id": { "type": "string", "description": "Unique identifier for the todo", }, "content": { "type": "string", "description": "Description of the todo task", }, "status": { "type": "string", "enum": ["pending", "in_progress", "completed"], "description": "Current status of the todo", }, }, "required": ["id", "content", "status"], }, } }, "required": ["todos"], }, } async def plan_tool_handler( arguments: Dict[str, Any], session: Any = None ) -> tuple[str, bool]: tool = PlanTool(session=session) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) ================================================ FILE: agent/tools/private_hf_repo_tools.py ================================================ """ Private HF Repos Tool - Manage private Hugging Face repositories PRIMARY USE: Store job outputs, training scripts, and logs from HF Jobs. Since job results are ephemeral, this tool provides persistent storage in private repos. SECONDARY USE: Read back stored files and list repo contents. """ import asyncio from typing import Any, Dict, Literal, Optional from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import HfHubHTTPError from agent.tools.types import ToolResult # Operation names OperationType = Literal[ "upload_file", "create_repo", "check_repo", "list_files", "read_file" ] async def _async_call(func, *args, **kwargs): """Wrap synchronous HfApi calls for async context.""" return await asyncio.to_thread(func, *args, **kwargs) def _build_repo_url(repo_id: str, repo_type: str = "dataset") -> str: """Build the Hub URL for a repository.""" type_path = "" if repo_type == "model" else f"{repo_type}s" return f"https://huggingface.co/{type_path}/{repo_id}".replace("//", "/") def _content_to_bytes(content: str | bytes) -> bytes: """Convert string or bytes content to bytes.""" if isinstance(content, str): return content.encode("utf-8") return content class PrivateHfRepoTool: """Tool for managing private Hugging Face repositories.""" def __init__(self, hf_token: Optional[str] = None): self.api = HfApi(token=hf_token) async def execute(self, params: Dict[str, Any]) -> ToolResult: """Execute the specified upload operation.""" operation = params.get("operation") args = params.get("args", {}) # If no operation provided, return usage instructions if not operation: return self._show_help() # Normalize operation name operation = operation.lower() # Check if help is requested if args.get("help"): return self._show_operation_help(operation) try: # Route to appropriate handler if operation == "upload_file": return await self._upload_file(args) elif operation == "create_repo": return await self._create_repo(args) elif operation == "check_repo": return await self._check_repo(args) elif operation == "list_files": return await self._list_files(args) elif operation == "read_file": return await self._read_file(args) else: return { "formatted": f'Unknown operation: "{operation}"\n\n' "Available operations: upload_file, create_repo, check_repo, list_files, read_file\n\n" "Call this tool with no operation for full usage instructions.", "totalResults": 0, "resultsShared": 0, "isError": True, } except HfHubHTTPError as e: return { "formatted": f"API Error: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } except Exception as e: return { 
"formatted": f"Error executing {operation}: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } def _show_help(self) -> ToolResult: """Show usage instructions when tool is called with no arguments.""" usage_text = """# Private HF Repos Tool **PRIMARY USE:** Store job outputs, scripts, and logs from HF Jobs to private repos. Since job results are ephemeral, use this tool for persistent storage. **SECONDARY USE:** Read back stored files and list repo contents. ## Available Commands ### Write Operations - **upload_file** - Upload file content to a repository - **create_repo** - Create a new private repository ### Read Operations - **list_files** - List all files in a repository - **read_file** - Read content of a specific file from a repository - **check_repo** - Check if a repository exists ## Examples ### Upload a script to a dataset repo Call this tool with: ```json { "operation": "upload_file", "args": { "file_content": "import pandas as pd\\nprint('Hello from HF!')", "path_in_repo": "scripts/hello.py", "repo_id": "my-dataset", "repo_type": "dataset", "create_if_missing": true, "commit_message": "Add hello script" } } ``` ### Upload logs from a job Call this tool with: ```json { "operation": "upload_file", "args": { "file_content": "Job started...\\nJob completed successfully!", "path_in_repo": "jobs/job-abc123/logs.txt", "repo_id": "job-results", "create_if_missing": true } } ``` ### Create a repository Call this tool with: ```json { "operation": "create_repo", "args": { "repo_id": "my-results", "repo_type": "dataset" } } ``` ### Create a Space Call this tool with: ```json { "operation": "create_repo", "args": { "repo_id": "my-gradio-app", "repo_type": "space", "space_sdk": "gradio" } } ``` Note: Repositories are always created as private. For spaces, `space_sdk` is required (gradio, streamlit, static, or docker). ### Check if a repository exists Call this tool with: ```json { "operation": "check_repo", "args": { "repo_id": "my-dataset", "repo_type": "dataset" } } ``` ### List files in a repository Call this tool with: ```json { "operation": "list_files", "args": { "repo_id": "job-results", "repo_type": "dataset" } } ``` ### Read a file from a repository Call this tool with: ```json { "operation": "read_file", "args": { "repo_id": "job-results", "path_in_repo": "jobs/job-abc123/script.py", "repo_type": "dataset" } } ``` ## Repository Types - **dataset** (default) - For storing data, results, logs, scripts - **model** - For ML models and related artifacts - **space** - For Spaces and applications ## Tips - **Content-based**: Pass file content directly as strings or bytes, not file paths - **Repo ID format**: Use just the repo name (e.g., "my-dataset"). Username is automatically inferred from HF_TOKEN - **Automatic repo creation**: Set `create_if_missing: true` to auto-create repos (requires user approval) - **Organization**: Use path_in_repo to organize files (e.g., "jobs/job-123/script.py") - **After jobs**: Upload job scripts and logs after compute jobs complete for reproducibility """ return {"formatted": usage_text, "totalResults": 1, "resultsShared": 1} def _show_operation_help(self, operation: str) -> ToolResult: """Show help for a specific operation.""" help_text = f"Help for operation: {operation}\n\nCall with appropriate arguments. Use the main help for examples." 
return {"formatted": help_text, "totalResults": 1, "resultsShared": 1} async def _upload_file(self, args: Dict[str, Any]) -> ToolResult: """Upload file content to a Hub repository.""" # Validate required arguments file_content = args.get("file_content") path_in_repo = args.get("path_in_repo") repo_id = args.get("repo_id") if not file_content: return { "formatted": "file_content is required", "totalResults": 0, "resultsShared": 0, "isError": True, } if not path_in_repo: return { "formatted": "path_in_repo is required", "totalResults": 0, "resultsShared": 0, "isError": True, } if not repo_id: return { "formatted": "repo_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } repo_type = args.get("repo_type", "dataset") create_if_missing = args.get("create_if_missing", False) # Check if repo exists try: repo_exists = await _async_call( self.api.repo_exists, repo_id=repo_id, repo_type=repo_type ) # Create repo if needed if not repo_exists and create_if_missing: create_args = { "repo_id": repo_id, "repo_type": repo_type, "private": True, } # Pass through space_sdk if provided (required for spaces) if "space_sdk" in args: create_args["space_sdk"] = args["space_sdk"] await self._create_repo(create_args) elif not repo_exists: return { "formatted": f"Repository {repo_id} does not exist. Set create_if_missing: true to create it.", "totalResults": 0, "resultsShared": 0, "isError": True, } except Exception as e: return { "formatted": f"Failed to check repository: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } # Convert content to bytes file_bytes = _content_to_bytes(file_content) # Upload file try: await _async_call( self.api.upload_file, path_or_fileobj=file_bytes, path_in_repo=path_in_repo, repo_id=repo_id, repo_type=repo_type, commit_message=args.get("commit_message", f"Upload {path_in_repo}"), ) repo_url = _build_repo_url(repo_id, repo_type) file_url = f"{repo_url}/blob/main/{path_in_repo}" response = f"""✓ File uploaded successfully! **Repository:** {repo_id} **File:** {path_in_repo} **View at:** {file_url} **Browse repo:** {repo_url}""" return {"formatted": response, "totalResults": 1, "resultsShared": 1} except Exception as e: return { "formatted": f"Failed to upload file: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } async def _create_repo(self, args: Dict[str, Any]) -> ToolResult: """Create a new Hub repository.""" repo_id = args.get("repo_id") if not repo_id: return { "formatted": "repo_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } repo_type = args.get("repo_type", "dataset") private = True # Always create private repos space_sdk = args.get("space_sdk") # Required if repo_type is "space" try: # Check if repo already exists repo_exists = await _async_call( self.api.repo_exists, repo_id=repo_id, repo_type=repo_type ) if repo_exists: repo_url = _build_repo_url(repo_id, repo_type) return { "formatted": f"Repository {repo_id} already exists.\n**View at:** {repo_url}", "totalResults": 1, "resultsShared": 1, } # Validate space_sdk for spaces if repo_type == "space" and not space_sdk: return { "formatted": "space_sdk is required when creating a space. 
Valid values: gradio, streamlit, static, docker", "totalResults": 0, "resultsShared": 0, "isError": True, } # Create repository create_kwargs = { "repo_id": repo_id, "repo_type": repo_type, "private": private, "exist_ok": True, } # Add space_sdk only for spaces if repo_type == "space" and space_sdk: create_kwargs["space_sdk"] = space_sdk repo_url = await _async_call(self.api.create_repo, **create_kwargs) response = f"""✓ Repository created successfully! **Repository:** {repo_id} **Type:** {repo_type} **Private:** Yes **View at:** {repo_url}""" return {"formatted": response, "totalResults": 1, "resultsShared": 1} except Exception as e: return { "formatted": f"Failed to create repository: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } async def _check_repo(self, args: Dict[str, Any]) -> ToolResult: """Check if a Hub repository exists.""" repo_id = args.get("repo_id") if not repo_id: return { "formatted": "repo_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } repo_type = args.get("repo_type", "dataset") try: repo_exists = await _async_call( self.api.repo_exists, repo_id=repo_id, repo_type=repo_type ) if repo_exists: repo_url = _build_repo_url(repo_id, repo_type) response = f"""✓ Repository exists! **Repository:** {repo_id} **Type:** {repo_type} **View at:** {repo_url}""" else: response = f"""Repository does not exist: {repo_id} To create it, call this tool with: ```json {{ "operation": "create_repo", "args": {{ "repo_id": "{repo_id}", "repo_type": "{repo_type}" }} }} ```""" return { "formatted": response, "totalResults": 1 if repo_exists else 0, "resultsShared": 1 if repo_exists else 0, } except Exception as e: return { "formatted": f"Failed to check repository: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } async def _list_files(self, args: Dict[str, Any]) -> ToolResult: """List all files in a Hub repository.""" repo_id = args.get("repo_id") if not repo_id: return { "formatted": "repo_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } repo_type = args.get("repo_type", "dataset") try: # List all files in the repository files = await _async_call( self.api.list_repo_files, repo_id=repo_id, repo_type=repo_type ) if not files: return { "formatted": f"No files found in repository: {repo_id}", "totalResults": 0, "resultsShared": 0, } # Format file list file_list = "\n".join(f"- {f}" for f in sorted(files)) repo_url = _build_repo_url(repo_id, repo_type) response = f"""✓ Files in repository: {repo_id} **Total files:** {len(files)} **Repository URL:** {repo_url} **Files:** {file_list}""" return { "formatted": response, "totalResults": len(files), "resultsShared": len(files), } except Exception as e: return { "formatted": f"Failed to list files: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } async def _read_file(self, args: Dict[str, Any]) -> ToolResult: """Read content of a specific file from a Hub repository.""" repo_id = args.get("repo_id") path_in_repo = args.get("path_in_repo") if not repo_id: return { "formatted": "repo_id is required", "totalResults": 0, "resultsShared": 0, "isError": True, } if not path_in_repo: return { "formatted": "path_in_repo is required", "totalResults": 0, "resultsShared": 0, "isError": True, } repo_type = args.get("repo_type", "dataset") try: # Download file to cache and read it file_path = await _async_call( hf_hub_download, repo_id=repo_id, filename=path_in_repo, repo_type=repo_type, token=self.api.token, ) # Read file content with open(file_path, "r", 
encoding="utf-8") as f: content = f.read() repo_url = _build_repo_url(repo_id, repo_type) file_url = f"{repo_url}/blob/main/{path_in_repo}" response = f"""✓ File read successfully! **Repository:** {repo_id} **File:** {path_in_repo} **Size:** {len(content)} characters **View at:** {file_url} **Content:** ``` {content} ```""" return {"formatted": response, "totalResults": 1, "resultsShared": 1} except UnicodeDecodeError: # If file is binary, return size info instead try: with open(file_path, "rb") as f: binary_content = f.read() return { "formatted": f"File is binary ({len(binary_content)} bytes). Cannot display as text.", "totalResults": 1, "resultsShared": 1, } except Exception as e: return { "formatted": f"Failed to read binary file: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } except Exception as e: return { "formatted": f"Failed to read file: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } # Tool specification for agent registration PRIVATE_HF_REPO_TOOL_SPEC = { "name": "hf_private_repos", "description": ( "Manage private HF repositories - create, upload, read, list files in models/datasets/spaces. " "⚠️ PRIMARY USE: Store job outputs persistently (job storage is EPHEMERAL - everything deleted after completion). " "**Use when:** (1) Job completes and need to store logs/scripts/results, (2) Creating repos for training outputs, " "(3) Reading back stored files, (4) Managing Space files, (5) Organizing job artifacts by path. " "**Pattern:** hf_jobs (ephemeral) → hf_private_repos upload_file (persistent) → can read_file later. " "ALWAYS pass file_content as string/bytes (✓), never file paths (✗) - this is content-based, no filesystem access. " "**Operations:** create_repo (new private repo), upload_file (store content), read_file (retrieve content), list_files (browse), check_repo (verify exists). " "**Critical for reliability:** Jobs lose all files after completion - use this tool to preserve important outputs. " "Repositories created are ALWAYS private by default (good for sensitive training data/models). " "For Spaces: must provide space_sdk ('gradio', 'streamlit', 'static', 'docker') when creating. " "**Then:** After uploading, provide user with repository URL for viewing/sharing." ), "parameters": { "type": "object", "properties": { "operation": { "type": "string", "enum": [ "upload_file", "create_repo", "check_repo", "list_files", "read_file", ], "description": ( "Operation to execute. Valid values: [upload_file, create_repo, check_repo, list_files, read_file]" ), }, "args": { "type": "object", "description": ( "Operation-specific arguments as a JSON object. " "Write ops: file_content (string/bytes), path_in_repo (string), repo_id (string), " "repo_type (dataset/model/space), create_if_missing (boolean), commit_message (string), " "space_sdk (gradio/streamlit/static/docker - required when repo_type=space). " "Read ops: repo_id (string), path_in_repo (for read_file), repo_type (optional)." 
), "additionalProperties": True, }, }, }, } async def private_hf_repo_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router.""" try: tool = PrivateHfRepoTool() result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error executing Private HF Repo tool: {str(e)}", False ================================================ FILE: agent/tools/research_tool.py ================================================ """ Research subagent tool — spawns a cheap LLM call with a focused research task and returns a summary. The subagent gets its own independent context (not the main conversation), so research work doesn't pollute the main agent's context window. Inspired by claude-code's code-explorer agent pattern. """ import json import logging from typing import Any from litellm import Message, acompletion from agent.core.doom_loop import check_for_doom_loop from agent.core.llm_params import _resolve_llm_params from agent.core.prompt_caching import with_prompt_caching from agent.core.session import Event logger = logging.getLogger(__name__) # Context budget for the research subagent (tokens). # When usage exceeds WARN threshold, the subagent is told to wrap up. # At MAX, the loop is force-stopped and whatever content exists is returned. _RESEARCH_CONTEXT_WARN = 170_000 # 85% of 200k _RESEARCH_CONTEXT_MAX = 190_000 # Tools the research agent can use (read-only subset) RESEARCH_TOOL_NAMES = { "read", "bash", "explore_hf_docs", "fetch_hf_docs", "find_hf_api", "hf_papers", "github_find_examples", "github_list_repos", "github_read_file", "hf_inspect_dataset", "hf_repo_files", } RESEARCH_SYSTEM_PROMPT = """\ You are a research sub-agent for an ML engineering assistant. Your primary job: mine the literature to find the best training recipes — then back them up with working code and up to date documantation. The main agent will use your findings to implement the actual solution. # Start from the literature Your default approach is a deep literature crawl. Do not start from docs or example scripts — start from papers. Papers contain the results, and results tell you what actually works. ## The crawl 1. **Find anchor papers**: Search for the task/domain. Identify the landmark paper(s) — high citations, recent, or both. 2. **Crawl the citation graph**: Use `citation_graph` on the anchor paper(s). Look DOWNSTREAM (papers that cite it) — these are the ones that built on it, improved it, or applied it to new domains. Prioritize recent papers and papers with many citations. 3. **Read methodology sections**: For the most promising papers (strong results, recent, relevant), use `read_paper` with section parameter to read sections 3, 4, 5 (Methodology, Experiments, Results — not the abstract). Extract: - The exact dataset(s) used (name, source, size, any filtering/preprocessing) - The training method and configuration (optimizer, lr, schedule, epochs, batch size) - The results those choices produced (benchmark scores, metrics, comparisons) 4. **Attribute results to recipes**: This is the critical step. Every finding must link a RESULT to the RECIPE that produced it. "Dataset X + method Y + lr Z → score W on benchmark V" is useful. "They used SFT" is not. 5. **Validate datasets**: For the most promising datasets, check if they exist on HF Hub with `hf_inspect_dataset`. Verify format matches the training method. Report if doesnt. 6. 
**Find code**: Now find working implementation code via `github_find_examples` and `github_read_file`. Use docs (`explore_hf_docs`, `fetch_hf_docs`) to fill in API details. ## When to go deeper - If the anchor paper is old (>1 year), its citation graph is your main source — the downstream papers will have better methods. - If a downstream paper reports significantly better results, crawl ITS citation graph too. - Use `snippet_search` to find specific claims across papers (e.g., "does dataset X consistently outperform Y for this task?"). - Use `recommend` to find related papers the citation graph might miss. # How to use your tools ## Papers & citations (USE FIRST) - `hf_papers(operation="search", query=...)`: Search papers (HF-tuned for ML) - `hf_papers(operation="search", query=..., min_citations=50, sort_by="citationCount")`: Find highly-cited papers via Semantic Scholar - `hf_papers(operation="search", query=..., date_from="2024-01-01")`: Search with date filter - `hf_papers(operation="paper_details", arxiv_id=...)`: Metadata, citations, TL;DR - `hf_papers(operation="citation_graph", arxiv_id=...)`: References + citations with influence flags and intents - `hf_papers(operation="read_paper", arxiv_id=..., section="3")`: Read a specific section's full text - `hf_papers(operation="read_paper", arxiv_id=...)`: Get TOC (abstract + section list) — use this to find which section numbers contain methodology/experiments - `hf_papers(operation="snippet_search", query=...)`: Semantic search across 12M+ full-text paper passages - `hf_papers(operation="recommend", arxiv_id=...)`: Find related papers - `hf_papers(operation="find_datasets", arxiv_id=...)`: Find HF datasets linked to a paper - `hf_papers(operation="find_all_resources", arxiv_id=...)`: Datasets + models + collections for a paper ## Dataset inspection - `hf_inspect_dataset`: Check dataset schema, splits, sample rows CRITICAL for training: verify column format matches training method: - SFT: needs "messages", "text", or "prompt"/"completion" - DPO: needs "prompt", "chosen", "rejected" - GRPO: needs "prompt" only ## GitHub code research - `github_find_examples`: Find working example scripts in HF repos (trl, transformers, etc.) - `github_read_file`: Read the actual implementation code. Use line_start/line_end for large files. ## Documentation - `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc. - `fetch_hf_docs(url)`: Fetch full page content from explore results - `find_hf_api(query=..., tag=...)`: Find REST API endpoints ## Hub repo inspection - `hf_repo_files`: List/read files in any HF repo (model, dataset, space) # Correct research pattern ``` # 1. Find anchor paper(s) for the task hf_papers({"operation": "search", "query": "GPQA graduate questions", "sort_by": "citationCount"}) # 2. Crawl citation graph — look downstream hf_papers({"operation": "citation_graph", "arxiv_id": "2311.12022", "direction": "citations"}) # 3. Read methodology of promising downstream papers hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348"}) # TOC first hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "3"}) # Methodology hf_papers({"operation": "read_paper", "arxiv_id": "2604.01348", "section": "4"}) # Experiments # 4. Find datasets used by these papers hf_papers({"operation": "find_datasets", "arxiv_id": "2604.01348"}) hf_papers({"operation": "find_all_resources", "arxiv_id": "2604.01348"}) # 5. 
Validate datasets exist and have correct format hf_inspect_dataset({"dataset": "org/dataset-name", "split": "train", "sample_rows": 3}) # 6. Now get working code for the training method github_find_examples({"repo": "trl", "keyword": "sft"}) github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"}) explore_hf_docs("trl") ``` # Output format Your output MUST be structured as a ranked list of training recipes, each attributed to published results: ## Recipe table (REQUIRED) For each promising approach found, report: - **Paper**: title, arxiv_id, date, venue - **Result**: exact benchmark scores and what they were measured on - **Dataset(s)**: name, size, source, HF Hub availability, format verified (yes/no) - **Method**: training approach, key hyperparameters (lr, epochs, batch size, optimizer, schedule) - **What made it work**: the specific insight or trick that drove the result (data curation, curriculum, loss function, etc.) Rank recipes by result quality. The main agent will pick the best one that's feasible. ## Code patterns - Key imports, configurations, and usage patterns from working examples - Specific file paths, URLs, function names from docs ## Recommendations - Which recipe to implement first and why - What datasets to use (with HF Hub paths, verified) - Any gaps: datasets that need preprocessing, methods that need adaptation Additionally include: - **SOTA landscape**: Current best models, datasets, and methods for the task (from recent papers). Flag anything outdated. - **Essential references**: Specific file paths, URLs, function names, doc sections, code snippets that the main agent should use directly Be concise. Your output goes into another agent's context — every token counts. Aim for 500-1500 words max. Include actual code snippets from examples you read, not paraphrased descriptions. """ RESEARCH_TOOL_SPEC = { "name": "research", "description": ( "Spawn a research sub-agent to explore documentation, codebases, " "or repos WITHOUT polluting the main conversation context. " "The sub-agent gets its own independent context window with read-only " "research tools and returns a concise summary of findings.\n\n" "Use this for:\n" "- Researching current API usage before implementing ML tasks " "(find examples + read docs)\n" "- Exploring HF docs, reading papers, analyzing GitHub repos\n" "- Any research where raw tool outputs would be too verbose\n\n" "The sub-agent knows how to use github_find_examples, github_read_file, " "explore_hf_docs, fetch_hf_docs, hf_inspect_dataset, hf_papers, etc. " "Just describe what you need researched." ), "parameters": { "type": "object", "properties": { "task": { "type": "string", "description": ( "Detailed description of what to research. Be specific: " "include library names, trainer types, dataset names, " "repo names, or doc pages to explore. Example: " "'Research current TRL SFTTrainer usage: find working " "example scripts, read the SFT documentation, and check " "SFTConfig parameters. Also validate that dataset " "HuggingFaceH4/ultrachat_200k has the right format for SFT.'" ), }, "context": { "type": "string", "description": ( "Optional context from the current conversation that the " "research agent needs (e.g., what the user wants to build, " "constraints, what's been tried)."
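# Hypothetical example added for clarity (illustrative scenario, not from the codebase):
"Example: 'User wants to fine-tune a 7B chat model; hardware is a single A10G; an SFT run with TRL has already been tried.'"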
), }, }, "required": ["task"], }, } def _get_research_model(main_model: str) -> str: """Pick a cheaper model for research based on the main model.""" if "anthropic" in main_model: return "bedrock/us.anthropic.claude-sonnet-4-6" # For non-Anthropic models (HF router etc.), use the same model return main_model async def research_handler( arguments: dict[str, Any], session=None, tool_call_id: str | None = None, **_kw ) -> tuple[str, bool]: """Execute a research sub-agent with its own context.""" task = arguments.get("task", "") context = arguments.get("context", "") if not task: return "No research task provided.", False if not session: return "No session available for research agent.", False # Build the sub-agent's messages (independent context) messages: list[Message] = [ Message(role="system", content=RESEARCH_SYSTEM_PROMPT), ] user_content = f"Research task: {task}" if context: user_content = f"Context: {context}\n\n{user_content}" messages.append(Message(role="user", content=user_content)) # Use a cheaper/faster model for research main_model = session.config.model_name research_model = _get_research_model(main_model) # Research is a cheap sub-call — cap the main session's effort at "high" # so a user preference of ``max``/``xhigh`` (valid for Opus 4.6/4.7) doesn't # propagate to a Sonnet research model that may not accept those levels. # We also haven't probed this sub-model so we don't know its ceiling. _pref = getattr(session.config, "reasoning_effort", None) _capped = "high" if _pref in ("max", "xhigh") else _pref llm_params = _resolve_llm_params( research_model, getattr(session, "hf_token", None), reasoning_effort=_capped, ) # Get read-only tool specs from the session's tool router tool_specs = [ spec for spec in session.tool_router.get_tool_specs_for_llm() if spec["function"]["name"] in RESEARCH_TOOL_NAMES ] # Unique ID + short label so parallel agents show separate status lines. # Use the tool_call_id when available — it's unique per invocation and lets # the frontend match a research tool card to its agent state. Fall back to # uuid for offline/test paths. Previously used md5(task), which collided # when the same task string was researched in parallel. 
if tool_call_id: _agent_id = tool_call_id else: import uuid _agent_id = uuid.uuid4().hex[:8] _agent_label = "research: " + (task[:50] + "…" if len(task) > 50 else task) async def _log(text: str) -> None: """Send a progress event to the UI so it doesn't look frozen.""" try: await session.send_event( Event(event_type="tool_log", data={ "tool": "research", "log": text, "agent_id": _agent_id, "label": _agent_label, }) ) except Exception: pass _tool_uses = 0 _total_tokens = 0 _warned_context = False await _log("Starting research sub-agent...") # Run the research loop — context budget is the real limiter max_iterations = 60 for _iteration in range(max_iterations): # ── Doom-loop detection ── doom_prompt = check_for_doom_loop(messages) if doom_prompt: logger.warning("Research sub-agent doom loop detected at iteration %d", _iteration) await _log("Doom loop detected — injecting corrective prompt") messages.append(Message(role="user", content=doom_prompt)) # ── Context budget: warn at 85%, hard-stop at 95% ── if _total_tokens >= _RESEARCH_CONTEXT_MAX: logger.warning( "Research sub-agent hit context max (%d tokens) — forcing summary", _total_tokens, ) await _log(f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up") # Ask for a final summary with no tools messages.append(Message( role="user", content=( "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. " "Summarize your findings NOW. Do NOT call any more tools." ), )) try: _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model")) response = await acompletion( messages=_msgs, tools=None, # no tools — force text response stream=False, timeout=120, **llm_params, ) content = response.choices[0].message.content or "" return content or "Research context exhausted — no summary produced.", bool(content) except Exception: return "Research context exhausted and summary call failed.", False if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN: _warned_context = True await _log(f"Context at {_total_tokens} tokens — nudging to wrap up") messages.append(Message( role="user", content=( "[SYSTEM: You have used 85% of your context budget. " "Start wrapping up: finish any critical lookups, then " "produce your final summary within the next 1-2 iterations.]" ), )) try: _msgs, _tools = with_prompt_caching( messages, tool_specs if tool_specs else None, llm_params.get("model") ) response = await acompletion( messages=_msgs, tools=_tools, tool_choice="auto", stream=False, timeout=120, **llm_params, ) except Exception as e: logger.error("Research sub-agent LLM error: %s", e) return f"Research agent LLM error: {e}", False # Track tokens if response.usage: _total_tokens = response.usage.total_tokens await _log(f"tokens:{_total_tokens}") choice = response.choices[0] msg = choice.message # If no tool calls, we have our final answer if not msg.tool_calls: await _log("Research complete.") content = msg.content or "Research completed but no summary generated." return content, True # Execute tool calls and add results. # Rebuild the assistant message with only the wire-safe fields — # LiteLLM's raw Message carries `provider_specific_fields` and # `reasoning_content`, which the HF router's OpenAI schema rejects # if we echo them back in the next request.
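# Sketch of the minimal OpenAI-style shape we re-send (illustrative, not the full schema):
# {"role": "assistant", "content": "...", "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "read", "arguments": "{...}"}}]}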
messages.append(Message( role="assistant", content=msg.content, tool_calls=msg.tool_calls, )) for tc in msg.tool_calls: try: tool_args = json.loads(tc.function.arguments) except (json.JSONDecodeError, TypeError): messages.append( Message( role="tool", content="Invalid tool arguments.", tool_call_id=tc.id, name=tc.function.name, ) ) continue tool_name = tc.function.name if tool_name not in RESEARCH_TOOL_NAMES: messages.append( Message( role="tool", content=f"Tool '{tool_name}' not available for research.", tool_call_id=tc.id, name=tool_name, ) ) continue try: import json as _json args_str = _json.dumps(tool_args)[:80] await _log(f"▸ {tool_name} {args_str}") output, _success = await session.tool_router.call_tool( tool_name, tool_args, session=session ) _tool_uses += 1 await _log(f"tools:{_tool_uses}") # Truncate tool output for the research context if len(output) > 8000: output = output[:4800] + "\n...(truncated)...\n" + output[-3200:] except Exception as e: output = f"Tool error: {e}" messages.append( Message( role="tool", content=output, tool_call_id=tc.id, name=tool_name, ) ) # ── Iteration limit: try to salvage findings ── await _log("Iteration limit reached — extracting summary") messages.append(Message( role="user", content=( "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research " "iterations. Summarize ALL findings so far. Do NOT call any more tools." ), )) try: _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model")) response = await acompletion( messages=_msgs, tools=None, stream=False, timeout=120, **llm_params, ) content = response.choices[0].message.content or "" if content: return content, True except Exception as e: logger.error("Research summary call failed: %s", e) return ( "Research agent hit iteration limit (60). " "Partial findings may be incomplete — try a more focused task.", False, ) ================================================ FILE: agent/tools/sandbox_client.py ================================================ #!/usr/bin/env python3 # /// script # requires-python = ">=3.10" # dependencies = ["huggingface_hub>=0.20.0", "httpx>=0.27.0"] # /// """ Sandbox Tools — Agent-native primitives for HF Space dev-mode sandboxes. 
Architecture: - Creates a sandbox by duplicating a template Space (runs sandbox_server.py) - Waits for it to come online - Communicates via HTTPS to the Space's API - Optionally deletes the Space when done Lifecycle: sb = Sandbox.create(owner="burtenshaw") # duplicate, wait, connect sb = Sandbox.create(owner="burtenshaw", # with options hardware="t4-small", private=True, sleep_time=3600) sb = Sandbox.connect("burtenshaw/my-sandbox-abc") # attach to existing sb.bash("uv run train.py") sb.read("/app/train.py") sb.edit("/app/train.py", old_str="lr=1e-3", new_str="lr=1e-4") sb.delete() # tear down when done # Or use as a context manager for automatic cleanup with Sandbox.create(owner="burtenshaw") as sb: sb.bash("python train.py") # Space deleted on exit Tools: bash, read, write, edit, upload """ from __future__ import annotations import io import sys import time import uuid from dataclasses import dataclass, field from typing import Any, Callable import httpx from huggingface_hub import CommitOperationAdd, HfApi TEMPLATE_SPACE = "burtenshaw/sandbox" HARDWARE_OPTIONS = [ "cpu-basic", "cpu-upgrade", "t4-small", "t4-medium", "a10g-small", "a10g-large", "a100-large", ] OUTPUT_LIMIT = 25000 LINE_LIMIT = 4000 DEFAULT_READ_LIMIT = 2000 DEFAULT_TIMEOUT = 240 MAX_TIMEOUT = 1200 WAIT_TIMEOUT = 600 WAIT_INTERVAL = 5 API_WAIT_TIMEOUT = 180 _DOCKERFILE = """\ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim RUN apt-get update && \\ apt-get install -y \\ bash git git-lfs wget curl procps \\ htop vim nano jq tmux \\ build-essential && \\ rm -rf /var/lib/apt/lists/* RUN uv pip install --system fastapi uvicorn python-multipart RUN useradd -m -u 1000 user USER user ENV HOME=/home/user \\ PATH=/home/user/.local/bin:$PATH \\ PIP_USER=1 \\ HF_HUB_DISABLE_PROGRESS_BARS=1 \\ TQDM_DISABLE=1 \\ HF_HUB_ENABLE_HF_TRANSFER=1 \\ UV_NO_PROGRESS=1 \\ PYTHONWARNINGS=ignore::DeprecationWarning WORKDIR /app COPY --chown=user . /app EXPOSE 7860 CMD ["python", "sandbox_server.py"] """ _SANDBOX_SERVER = '''\ """Minimal FastAPI server for sandbox operations.""" import os, subprocess, pathlib, signal, threading, re, tempfile from fastapi import FastAPI from pydantic import BaseModel from typing import Optional import uvicorn _ANSI_RE = re.compile(r'\\x1b\\[[0-9;]*[a-zA-Z]|\\x1b\\].*?\\x07') def _strip_ansi(text: str) -> str: return _ANSI_RE.sub('', text) def _truncate_output(output: str, max_chars: int = 25000, head_ratio: float = 0.25) -> str: if len(output) <= max_chars: return output # Write full output to temp file so LLM can read specific sections spill_path = None try: with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', prefix='bash_output_', dir='/tmp', delete=False) as f: f.write(output) spill_path = f.name except Exception: pass head_budget = int(max_chars * head_ratio) tail_budget = max_chars - head_budget head = output[:head_budget] tail = output[-tail_budget:] total = len(output) omitted = total - max_chars meta = f"\\n\\n... 
({omitted:,} of {total:,} chars omitted, showing first {head_budget:,} + last {tail_budget:,}) ...\\n" if spill_path: meta += f"Full output saved to {spill_path} — use the read tool with offset/limit to inspect specific sections.\\n" return head + meta + tail def _atomic_write(path: pathlib.Path, content: str): """Write atomically: temp file + fsync + os.replace.""" path.parent.mkdir(parents=True, exist_ok=True) fd = None tmp_path = None try: fd, tmp_path = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp") os.write(fd, content.encode("utf-8")) os.fsync(fd) os.close(fd) fd = None os.replace(tmp_path, str(path)) tmp_path = None finally: if fd is not None: os.close(fd) if tmp_path is not None: try: os.unlink(tmp_path) except OSError: pass app = FastAPI() # Track active bash processes so they can be killed on cancel _active_procs = {} # pid -> subprocess.Popen _proc_lock = threading.Lock() class BashReq(BaseModel): command: str work_dir: str = "/app" timeout: int = 120 class ReadReq(BaseModel): path: str offset: Optional[int] = None limit: Optional[int] = 2000 class WriteReq(BaseModel): path: str content: str class EditReq(BaseModel): path: str old_str: str new_str: str replace_all: bool = False mode: str = "replace" class ExistsReq(BaseModel): path: str # ── Fuzzy matching & edit utilities (embedded) ── UNICODE_MAP = { "\\u2013": "-", "\\u2014": "-", "\\u2212": "-", "\\u2018": "'", "\\u2019": "'", "\\u201c": \'"\', "\\u201d": \'"\', "\\u00a0": " ", "\\u2003": " ", "\\u2002": " ", "\\u200b": "", "\\ufeff": "", } def _normalize_unicode(s): return "".join(UNICODE_MAP.get(c, c) for c in s) def _fuzzy_find_original(content, pattern): """Find the original text in content that matches pattern fuzzily.""" if pattern in content: return pattern, None # Pass 2: right-trim c_lines = content.split("\\n") c_rt = "\\n".join(l.rstrip() for l in c_lines) p_rt = "\\n".join(l.rstrip() for l in pattern.split("\\n")) if p_rt in c_rt: idx = c_rt.index(p_rt) start_line = c_rt[:idx].count("\\n") n_lines = p_rt.count("\\n") + 1 matched = "\\n".join(c_lines[start_line:start_line + n_lines]) return matched, "(matched after trimming trailing whitespace)" # Pass 3: both-sides trim c_st = "\\n".join(l.strip() for l in c_lines) p_st = "\\n".join(l.strip() for l in pattern.split("\\n")) if p_st in c_st: idx = c_st.index(p_st) start_line = c_st[:idx].count("\\n") n_lines = p_st.count("\\n") + 1 matched = "\\n".join(c_lines[start_line:start_line + n_lines]) return matched, "(matched after trimming whitespace)" # Pass 4: unicode normalization c_norm = _normalize_unicode(c_st) p_norm = _normalize_unicode(p_st) if p_norm in c_norm: idx = c_norm.index(p_norm) start_line = c_norm[:idx].count("\\n") n_lines = p_norm.count("\\n") + 1 matched = "\\n".join(c_lines[start_line:start_line + n_lines]) return matched, "(matched after unicode normalization)" return None, None def _apply_edit(content, old_str, new_str, mode="replace", replace_all=False): """Apply edit. Returns (new_content, count, fuzzy_note) or raises ValueError.""" if mode == "replace_all": replace_all = True mode = "replace" fuzzy_note = None if old_str not in content: matched, fuzzy_note = _fuzzy_find_original(content, old_str) if matched is None: raise ValueError("old_str not found in file.") old_str = matched count = content.count(old_str) if mode == "replace": if count > 1 and not replace_all: raise ValueError(f"old_str appears {count} times. 
Use replace_all=true or provide more context.") if replace_all: return content.replace(old_str, new_str), count, fuzzy_note return content.replace(old_str, new_str, 1), 1, fuzzy_note elif mode == "append_after": if replace_all: return content.replace(old_str, old_str + new_str), count, fuzzy_note idx = content.index(old_str) + len(old_str) return content[:idx] + new_str + content[idx:], 1, fuzzy_note elif mode == "prepend_before": if replace_all: return content.replace(old_str, new_str + old_str), count, fuzzy_note idx = content.index(old_str) return content[:idx] + new_str + content[idx:], 1, fuzzy_note raise ValueError(f"Unknown mode: {mode}") def _validate_python(content, path=""): """Validate Python: syntax, kwargs against real installed signatures, training heuristics. Runs inside the sandbox where packages are pip-installed, so we can actually import classes and inspect their __init__ signatures to catch kwarg mismatches before runtime. """ import ast as _ast, inspect as _inspect, importlib as _il warnings = [] # 1. Syntax check try: tree = _ast.parse(content) except SyntaxError as e: warnings.append(f"Python syntax error at line {e.lineno}: {e.msg}") return warnings # 2. Build import map: name -> module path (from the script's own imports) import_map = {} for node in _ast.walk(tree): if isinstance(node, _ast.ImportFrom) and node.module: for alias in (node.names or []): local_name = alias.asname or alias.name import_map[local_name] = (node.module, alias.name) elif isinstance(node, _ast.Import): for alias in (node.names or []): local_name = alias.asname or alias.name import_map[local_name] = (alias.name, None) # 3. For each Call node, resolve the callable and check kwargs against signature for node in _ast.walk(tree): if not isinstance(node, _ast.Call): continue # Skip calls with **kwargs unpacking — we can't statically know those keys if any(kw.arg is None for kw in node.keywords): continue call_kwargs = [kw.arg for kw in node.keywords if kw.arg] if not call_kwargs: continue # Resolve the callable name func_name = None if isinstance(node.func, _ast.Name): func_name = node.func.id elif isinstance(node.func, _ast.Attribute): func_name = node.func.attr if not func_name or func_name not in import_map: continue # Try to import and inspect the real callable module_path, attr_name = import_map[func_name] try: mod = _il.import_module(module_path) obj = getattr(mod, attr_name, None) if attr_name else mod if obj is None: continue sig = _inspect.signature(obj) params = sig.parameters # If **kwargs is in the signature, any kwarg is valid if any(p.kind == _inspect.Parameter.VAR_KEYWORD for p in params.values()): continue valid_names = set(params.keys()) for kw_name in call_kwargs: if kw_name not in valid_names: warnings.append( f"Invalid kwarg: {func_name}({kw_name}=...) at line {node.lineno} " f"-- not accepted by {module_path}.{attr_name or func_name}()" ) except Exception: pass # can't import/inspect — skip silently # 4. 
Training script heuristics if any(kw in content for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")): if "push_to_hub" not in content: warnings.append("Training script warning: no \'push_to_hub\' found") if "hub_model_id" not in content: warnings.append("Training script warning: no \'hub_model_id\' found") return warnings @app.get("/api/health") def health(): return {"status": "ok"} @app.post("/api/bash") def bash(req: BashReq): try: proc = subprocess.Popen( req.command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=req.work_dir, start_new_session=True, ) with _proc_lock: _active_procs[proc.pid] = proc try: stdout, stderr = proc.communicate(timeout=req.timeout) output = _strip_ansi(stdout + stderr) output = _truncate_output(output) return {"success": proc.returncode == 0, "output": output, "error": "" if proc.returncode == 0 else f"Exit code {proc.returncode}"} except subprocess.TimeoutExpired: try: os.killpg(os.getpgid(proc.pid), signal.SIGKILL) except OSError: proc.kill() proc.wait() return {"success": False, "output": "", "error": f"Timeout after {req.timeout}s"} finally: with _proc_lock: _active_procs.pop(proc.pid, None) except Exception as e: return {"success": False, "output": "", "error": str(e)} @app.post("/api/kill") def kill_all(): """Kill all active bash processes. Called when user cancels.""" with _proc_lock: pids = list(_active_procs.keys()) killed = [] for pid in pids: try: os.killpg(os.getpgid(pid), signal.SIGTERM) killed.append(pid) except OSError: try: os.kill(pid, signal.SIGKILL) killed.append(pid) except OSError: pass return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""} @app.post("/api/read") def read(req: ReadReq): try: p = pathlib.Path(req.path) if not p.exists(): return {"success": False, "output": "", "error": f"File not found: {req.path}"} if p.is_dir(): return {"success": False, "output": "", "error": f"Is a directory: {req.path}"} lines = p.read_text().splitlines() start = (req.offset or 1) - 1 end = start + (req.limit or len(lines)) selected = lines[start:end] numbered = "\\n".join(f"{start + i + 1}\\t{line}" for i, line in enumerate(selected)) return {"success": True, "output": numbered, "error": ""} except Exception as e: return {"success": False, "output": "", "error": str(e)} @app.post("/api/write") def write(req: WriteReq): try: p = pathlib.Path(req.path) _atomic_write(p, req.content) msg = f"Wrote {len(req.content)} bytes to {req.path}" if p.suffix == ".py": warnings = _validate_python(req.content, req.path) if warnings: msg += "\\n\\nValidation warnings:\\n" + "\\n".join(f" ! 
{w}" for w in warnings) return {"success": True, "output": msg, "error": ""} except Exception as e: return {"success": False, "output": "", "error": str(e)} @app.post("/api/edit") def edit(req: EditReq): try: p = pathlib.Path(req.path) if not p.exists(): return {"success": False, "output": "", "error": f"File not found: {req.path}"} content = p.read_text() if req.old_str == req.new_str: return {"success": False, "output": "", "error": "old_str and new_str must differ."} try: new_content, count, fuzzy_note = _apply_edit( content, req.old_str, req.new_str, mode=req.mode, replace_all=req.replace_all ) except ValueError as e: return {"success": False, "output": "", "error": str(e)} _atomic_write(p, new_content) msg = f"Edited {req.path} ({count} replacement{'s' if count > 1 else ''})" if fuzzy_note: msg += f" {fuzzy_note}" if p.suffix == ".py": warnings = _validate_python(new_content, req.path) if warnings: msg += "\\n\\nValidation warnings:\\n" + "\\n".join(f" ! {w}" for w in warnings) return {"success": True, "output": msg, "error": ""} except Exception as e: return {"success": False, "output": "", "error": str(e)} @app.post("/api/exists") def exists(req: ExistsReq): return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""} if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860) ''' @dataclass class ToolResult: success: bool output: str = "" error: str = "" def __str__(self): if self.success: return self.output or "(no output)" return f"ERROR: {self.error}" def to_dict(self) -> dict: return {"success": self.success, "output": self.output, "error": self.error} @dataclass class Sandbox: """ A handle to an HF Space sandbox. Use Sandbox.create() to spin up a new one, or Sandbox.connect() to attach to an existing running Space. """ space_id: str token: str | None = None work_dir: str = "/app" timeout: int = DEFAULT_TIMEOUT _owns_space: bool = field(default=False, repr=False) _base_url: str = field(init=False, repr=False) _client: httpx.Client = field(init=False, repr=False) _hf_api: HfApi = field(init=False, repr=False) _files_read: set = field(init=False, repr=False, default_factory=set) def __post_init__(self): slug = self.space_id.replace("/", "-") # Trailing slash is critical: httpx resolves relative paths against base_url. # Without it, client.get("health") resolves to /health instead of /api/health. self._base_url = f"https://{slug}.hf.space/api/" self._client = httpx.Client( base_url=self._base_url, headers={"Authorization": f"Bearer {self.token}"} if self.token else {}, timeout=httpx.Timeout(MAX_TIMEOUT, connect=30), follow_redirects=True, ) self._hf_api = HfApi(token=self.token) # ── Lifecycle ───────────────────────────────────────────────── class Cancelled(Exception): """Raised when sandbox creation is cancelled by the user.""" @classmethod def create( cls, owner: str, *, name: str | None = None, template: str = TEMPLATE_SPACE, hardware: str = "cpu-basic", private: bool = False, sleep_time: int | None = None, token: str | None = None, secrets: dict[str, str] | None = None, wait_timeout: int = WAIT_TIMEOUT, log: "Callable[[str], object] | None" = None, cancel_event: "Any | None" = None, ) -> Sandbox: """ Create a new sandbox by duplicating the template Space. Generates a unique space name, duplicates the template, waits for it to come online, then returns a connected Sandbox. Args: owner: HF username or org (e.g. "burtenshaw"). name: Base name for the space. Defaults to "sandbox". A unique suffix is always appended. 
template: Source Space to duplicate (default: burtenshaw/sandbox). hardware: Hardware tier (cpu-basic, t4-small, etc.). private: Whether the Space should be private. sleep_time: Auto-sleep after N seconds of inactivity. token: HF API token (from user's OAuth session). wait_timeout: Max seconds to wait for Space to start (default: 600). cancel_event: A threading.Event (or compatible) checked during polling loops. When set, the Space is deleted and Sandbox.Cancelled is raised. Returns: A Sandbox instance connected to the running Space. """ _log = log or print api = HfApi(token=token) def _check_cancel(): if cancel_event and cancel_event.is_set(): _log("Sandbox creation cancelled by user, cleaning up...") try: api.delete_repo(space_id, repo_type="space") _log(f"Deleted Space {space_id}") except Exception: pass raise cls.Cancelled(f"Sandbox creation cancelled: {space_id}") base = name or "sandbox" suffix = uuid.uuid4().hex[:8] space_id = f"{owner}/{base}-{suffix}" _log(f"Creating sandbox: {space_id} (from {template})...") kwargs = { "from_id": template, "to_id": space_id, "private": private, "hardware": hardware, } if sleep_time is not None: kwargs["sleep_time"] = sleep_time api.duplicate_space(**kwargs) _log(f"Space created: https://huggingface.co/spaces/{space_id}") _check_cancel() # Inject secrets BEFORE uploading server files (which triggers rebuild). # Secrets added after a Space is running aren't available until restart, # so they must be set before the build/start cycle. if secrets: for key, val in secrets.items(): api.add_space_secret(space_id, key, val) # Upload sandbox server and Dockerfile (triggers rebuild) cls._setup_server(space_id, api, log=_log) _check_cancel() # Wait for it to come online (rebuild + start) _log(f"Waiting for Space to start (timeout: {wait_timeout}s)...") deadline = time.time() + wait_timeout while time.time() < deadline: _check_cancel() runtime = api.get_space_runtime(space_id) if runtime.stage == "RUNNING": _log(f"Space is running (hardware: {runtime.hardware})") break if runtime.stage in ("RUNTIME_ERROR", "BUILD_ERROR"): raise RuntimeError( f"Space failed to start: {runtime.stage}. " f"Check https://huggingface.co/spaces/{space_id}" ) _log(f" {runtime.stage}...") time.sleep(WAIT_INTERVAL) else: raise TimeoutError( f"Space did not start within {wait_timeout}s. " f"Check https://huggingface.co/spaces/{space_id}" ) _check_cancel() # Wait for the API server to be responsive (non-fatal) sb = cls(space_id=space_id, token=token, _owns_space=True) try: sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log) except TimeoutError as e: _log( f"Warning: API health check timed out ({e}), but Space is RUNNING. Continuing." ) return sb @staticmethod def _setup_server(space_id: str, api: HfApi, *, log: Callable[[str], object] = print) -> None: """Upload embedded sandbox server + Dockerfile to the Space (single commit).""" log(f"Uploading sandbox server to {space_id}...") api.create_commit( repo_id=space_id, repo_type="space", operations=[ CommitOperationAdd( path_in_repo="sandbox_server.py", path_or_fileobj=io.BytesIO(_SANDBOX_SERVER.encode()), ), CommitOperationAdd( path_in_repo="Dockerfile", path_or_fileobj=io.BytesIO(_DOCKERFILE.encode()), ), ], commit_message="Setup sandbox server", ) log("Server files uploaded, rebuild triggered.") @classmethod def connect(cls, space_id: str, *, token: str | None = None) -> Sandbox: """ Connect to an existing running Space. Does a health check to verify the Space is reachable.
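Example (a minimal sketch; the space id and token are placeholders): sb = Sandbox.connect("your-username/sandbox-abc12345", token=hf_token); print(sb.bash("python --version"))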
""" sb = cls(space_id=space_id, token=token, _owns_space=False) sb._wait_for_api(timeout=60) return sb def _wait_for_api(self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print): """Poll the health endpoint until the server responds.""" deadline = time.time() + timeout last_err = None last_status = None while time.time() < deadline: try: resp = self._client.get("health", timeout=10) last_status = resp.status_code if resp.status_code == 200: log(f"API is responsive at {self._base_url}") return except Exception as e: last_err = e time.sleep(3) raise TimeoutError( f"Sandbox API at {self._base_url} not responding after {timeout}s. " f"Last status: {last_status}, last error: {last_err}" ) def delete(self): """Delete the Space. Only works if this Sandbox created it.""" if not self._owns_space: raise RuntimeError( f"This Sandbox did not create {self.space_id}. " f"Use self._hf_api.delete_repo() directly if you're sure." ) print(f"Deleting sandbox: {self.space_id}...") self._hf_api.delete_repo(self.space_id, repo_type="space") self._client.close() print("Deleted.") def pause(self): """Pause the Space (stops billing, preserves state).""" self._hf_api.pause_space(self.space_id) def restart(self): """Restart the Space.""" self._hf_api.restart_space(self.space_id) self._wait_for_api() @property def url(self) -> str: """Public URL of the Space.""" return f"https://huggingface.co/spaces/{self.space_id}" @property def status(self) -> str: """Current Space stage (RUNNING, BUILDING, PAUSED, etc.).""" return self._hf_api.get_space_runtime(self.space_id).stage def __enter__(self) -> Sandbox: return self def __exit__(self, *exc): if self._owns_space: try: self.delete() except Exception as e: print(f"Warning: failed to delete sandbox: {e}", file=sys.stderr) self._client.close() # ── HTTP plumbing ───────────────────────────────────────────── def _call( self, endpoint: str, payload: dict, timeout: float | None = None ) -> ToolResult: # Strip leading slash for correct httpx base_url resolution endpoint = endpoint.lstrip("/") effective_timeout = timeout or self.timeout last_error = "" # Retry up to 3 times for transient failures (sandbox waking from # sleep returns empty / non-JSON responses while it starts up). for attempt in range(3): try: resp = self._client.post( endpoint, json=payload, timeout=effective_timeout, ) try: data = resp.json() except (ValueError, UnicodeDecodeError): # Non-JSON response — sandbox is likely still starting up. body_preview = resp.text[:200] if resp.text else "(empty)" last_error = ( f"Sandbox returned non-JSON response (HTTP {resp.status_code}): " f"{body_preview}" ) if attempt < 2: time.sleep(3 * (attempt + 1)) continue return ToolResult(success=False, error=last_error) if resp.status_code == 200: return ToolResult( success=data.get("success", True), output=data.get("output", ""), error=data.get("error", ""), ) return ToolResult( success=False, error=data.get("error", f"HTTP {resp.status_code}"), ) except httpx.TimeoutException: return ToolResult( success=False, error=f"Timeout after {effective_timeout}s" ) except httpx.ConnectError: last_error = ( f"Cannot connect to sandbox. Is {self.space_id} running? 
" f"Status: {self.status}" ) if attempt < 2: time.sleep(3 * (attempt + 1)) continue return ToolResult(success=False, error=last_error) except Exception as e: return ToolResult(success=False, error=str(e)) return ToolResult(success=False, error=last_error or "Unknown error") # ── Tools ───────────────────────────────────────────────────── def bash( self, command: str, *, work_dir: str | None = None, timeout: int | None = None, description: str | None = None, ) -> ToolResult: return self._call( "bash", { "command": command, "work_dir": work_dir or self.work_dir, "timeout": min(timeout or self.timeout, MAX_TIMEOUT), }, timeout=timeout, ) def read( self, path: str, *, offset: int | None = None, limit: int | None = None ) -> ToolResult: self._files_read.add(path) return self._call( "read", { "path": path, "offset": offset, "limit": limit or (DEFAULT_READ_LIMIT if offset is None else None), }, ) def write(self, path: str, content: str) -> ToolResult: if path not in self._files_read: check = self._call("exists", {"path": path}) if check.success and check.output == "true": return ToolResult( success=False, error=( f"File {path} exists but has not been read this session. " f"Read it first, or use sandbox_edit for targeted changes." ), ) result = self._call("write", {"path": path, "content": content}) if result.success: self._files_read.add(path) return result def edit( self, path: str, old_str: str, new_str: str, *, replace_all: bool = False, mode: str = "replace", ) -> ToolResult: if old_str == new_str: return ToolResult(success=False, error="old_str and new_str are identical.") if path not in self._files_read: return ToolResult( success=False, error=f"File {path} has not been read this session. Read it first.", ) return self._call( "edit", { "path": path, "old_str": old_str, "new_str": new_str, "replace_all": replace_all, "mode": mode, }, ) def kill_all(self) -> ToolResult: """Kill all active bash processes on the sandbox. Used on cancellation.""" return self._call("kill", {}) # ── Tool schemas & dispatch ─────────────────────────────────── TOOLS = { "bash": { "description": ( "Run a shell command in the remote sandbox and return stdout/stderr.\n" "\n" "IMPORTANT: Do NOT use bash for file operations — use the dedicated tools instead:\n" "- To read files: use read (not cat/head/tail)\n" "- To edit files: use edit (not sed/awk)\n" "- To write files: use write (not echo/cat < > /app/output.log 2>&1 & echo $!\n" "Then check status:\n" " kill -0 2>/dev/null && echo 'running' || echo 'done'\n" " tail -n 50 /app/output.log\n" "\n" "Timeout default 240s, max 1200s." ), "parameters": { "type": "object", "required": ["command"], "additionalProperties": False, "properties": { "command": { "type": "string", "description": "The shell command to execute.", }, "description": { "type": "string", "description": "Short description (5-10 words, active voice).", }, "work_dir": { "type": "string", "description": "Working directory (default: /app).", }, "timeout": { "type": "integer", "description": "Optional timeout in seconds (default: 240, max: 1200).", }, }, }, }, "read": { "description": ( "Reads a file from the sandbox filesystem. 
Returns contents with line " "numbers (cat -n format).\n" "\n" "Usage:\n" "- By default, reads up to 2000 lines from the beginning of the file.\n" "- You can optionally specify offset and limit for large files, but prefer " "reading the whole file first.\n" "- Lines longer than 4000 chars are truncated.\n" "- Cannot read directories — use bash with 'ls' instead.\n" "- You should read multiple potentially useful files in parallel when possible.\n" "- IMPORTANT: Always read a file before editing or overwriting it. The edit and " "write tools will reject operations on files you haven't read." ), "parameters": { "type": "object", "required": ["path"], "additionalProperties": False, "properties": { "path": { "type": "string", "description": "Absolute path to the file to read.", }, "offset": { "type": "integer", "description": "The line number to start reading from (1-based). Only provide if the file is too large to read at once.", }, "limit": { "type": "integer", "description": "The number of lines to read. Only provide if the file is too large to read at once.", }, }, }, }, "write": { "description": ( "Writes a file to the sandbox filesystem. Overwrites the existing file if " "one exists at the path.\n" "\n" "- If this is an existing file, you MUST use the read tool first. This tool " "will fail if you did not read the file first.\n" "- ALWAYS prefer editing existing files with the edit tool over overwriting " "with write.\n" "- Creates parent directories as needed." ), "parameters": { "type": "object", "required": ["path", "content"], "additionalProperties": False, "properties": { "path": { "type": "string", "description": "Absolute path to the file to write.", }, "content": { "type": "string", "description": "The complete file content to write.", }, }, }, }, "edit": { "description": ( "Performs string replacements in files. Supports exact matching with " "fuzzy fallback.\n" "\n" "Usage:\n" "- You must read the file at least once before editing. This tool will " "error if you attempt an edit without reading the file.\n" "- The edit will FAIL if old_str is not unique in the file. Either provide " "a larger string with more surrounding context to make it unique, or set " "replace_all to true.\n" "- old_str and new_str must differ.\n" "- Preserve indentation exactly as it appears in the file.\n" "- Do NOT include line number prefixes from read output in old_str or new_str.\n" "- To delete code, set new_str to empty string.\n" "- Use replace_all for renaming variables or strings across the file.\n" "\n" "Modes:\n" "- replace (default): replace first occurrence of old_str with new_str.\n" "- append_after: insert new_str immediately after old_str (old_str is kept).\n" "- prepend_before: insert new_str immediately before old_str (old_str is kept)." ), "parameters": { "type": "object", "required": ["path", "old_str", "new_str"], "additionalProperties": False, "properties": { "path": { "type": "string", "description": "Absolute path to the file to edit.", }, "old_str": { "type": "string", "description": "The text to find in the file. Must match exactly (fuzzy matching is used as fallback).", }, "new_str": { "type": "string", "description": "The replacement text. 
For append_after/prepend_before modes, the text to insert.", }, "replace_all": { "type": "boolean", "description": "Replace all occurrences of old_str (default: false).", "default": False, }, "mode": { "type": "string", "enum": ["replace", "append_after", "prepend_before"], "description": "Edit mode (default: replace).", "default": "replace", }, }, }, }, } @classmethod def tool_definitions(cls) -> list[dict]: return [{"name": name, **spec} for name, spec in cls.TOOLS.items()] def call_tool(self, name: str, arguments: dict[str, Any]) -> ToolResult: dispatch = { "bash": lambda a: self.bash( a["command"], work_dir=a.get("work_dir"), timeout=a.get("timeout"), description=a.get("description"), ), "read": lambda a: self.read( a["path"], offset=a.get("offset"), limit=a.get("limit"), ), "write": lambda a: self.write(a["path"], a["content"]), "edit": lambda a: self.edit( a["path"], a["old_str"], a["new_str"], replace_all=a.get("replace_all", False), mode=a.get("mode", "replace"), ), } fn = dispatch.get(name) if not fn: return ToolResult(success=False, error=f"Unknown tool: {name}") return fn(arguments) ================================================ FILE: agent/tools/sandbox_tool.py ================================================ """ Sandbox tools — expose the Sandbox client as agent tools. 5 tools total: sandbox_create — explicit sandbox creation (requires approval) bash, read, write, edit — operations on the sandbox If any operation tool is called without an active sandbox, the call fails with an error telling the agent to run sandbox_create first. """ from __future__ import annotations import asyncio import threading from typing import Any from huggingface_hub import HfApi, SpaceHardware from agent.core.session import Event from agent.tools.sandbox_client import Sandbox def _looks_like_path(script: str) -> bool: """Return True if the script string looks like a file path (not inline code).""" return ( isinstance(script, str) and script.strip() == script and not any(c in script for c in "\r\n\0") and ( script.startswith("/") or script.startswith("./") or script.startswith("../") ) ) async def resolve_sandbox_script( sandbox: Any, script: str ) -> tuple[str | None, str | None]: """Read a file from the sandbox if *script* looks like a path. Returns: (content, error) — content is the file text on success, error is a message on failure. Both None means *script* is not a path (caller should use it as-is). """ if not sandbox or not _looks_like_path(script): return None, None try: # Use the read endpoint instead of bash("cat ...") which truncates at 25KB. result = await asyncio.to_thread(sandbox.read, script, limit=100_000) if result.success and result.output: # Strip line number prefixes (read returns "N\tcontent" format) lines = [] for line in result.output.split("\n"): parts = line.split("\t", 1) lines.append(parts[1] if len(parts) == 2 else line) return "\n".join(lines), None return None, f"Failed to read {script} from sandbox: {result.error}" except Exception as e: return None, f"Failed to read {script} from sandbox: {e}" # ── Tool name mapping (short agent names → Sandbox client names) ────── async def _ensure_sandbox( session: Any, hardware: str = "cpu-basic", **create_kwargs ) -> tuple[Sandbox | None, str | None]: """ Ensure a sandbox exists on the session. Auto-creates with given hardware if needed. Returns: (sandbox, error_message) — one will be None. """ if session and getattr(session, "sandbox", None): return session.sandbox, None if not session: return None, "No session available."
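# Resolve the owner namespace from the token: whoami() identifies the account the token belongs to, so the sandbox Space is created under the caller's own username.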
token = session.hf_token if not token: return None, "No HF token available. Cannot create sandbox." api = HfApi(token=token) user_info = api.whoami() owner = user_info.get("name", user_info.get("user", "")) if not owner: return None, "Could not determine HF username from token." await session.send_event( Event( event_type="tool_log", data={ "tool": "sandbox", "log": f"Auto-creating sandbox for {owner} ({hardware})...", }, ) ) # Thread-safe log callback: posts tool_log events from the worker thread loop = asyncio.get_running_loop() def _log(msg: str) -> None: loop.call_soon_threadsafe( session.event_queue.put_nowait, Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}), ) # Bridge asyncio cancel event to a threading.Event for the blocking create call. # We poll session._cancelled from the main loop in a background task and set # a threading.Event that Sandbox.create checks during its polling loops. cancel_flag = threading.Event() async def _watch_cancel(): await session._cancelled.wait() cancel_flag.set() watcher_task = asyncio.create_task(_watch_cancel()) kwargs = { "owner": owner, "hardware": hardware, "token": token, "secrets": {"HF_TOKEN": token}, "log": _log, "cancel_event": cancel_flag, **create_kwargs, } if hardware != "cpu-basic": kwargs["sleep_time"] = 2700 try: sb = await asyncio.to_thread(Sandbox.create, **kwargs) except Sandbox.Cancelled: return None, "Sandbox creation cancelled by user." finally: watcher_task.cancel() session.sandbox = sb # Set a descriptive title (template title is inherited on duplicate) from huggingface_hub import metadata_update await asyncio.to_thread( metadata_update, sb.space_id, {"title": "ml-intern sandbox"}, repo_type="space", overwrite=True, token=token, ) await session.send_event( Event( event_type="tool_log", data={"tool": "sandbox", "log": f"Sandbox ready: {sb.space_id} ({sb.url})"}, ) ) return sb, None # ── sandbox_create tool ────────────────────────────────────────────── SANDBOX_CREATE_TOOL_SPEC = { "name": "sandbox_create", "description": ( "Create a persistent remote Linux environment for developing and testing scripts.\n\n" "Workflow: sandbox_create → write script → pip install → test with small run → fix errors → hf_jobs at scale.\n" "The sandbox persists across tool calls within the session. pip install works out of the box.\n\n" "Use this when: you need to develop, test, and iterate on scripts before launching via hf_jobs. " "Especially for training scripts where you need to verify imports, test on a small subset, and fix errors interactively.\n\n" "Skip this when: the task is a simple one-shot operation (status check, resource search, quick data query), " "or the script is copied from a verified working example with minimal changes.\n\n" "For ML code that uses CUDA, bf16, or model loading: use GPU hardware (t4-small minimum). " "CPU sandboxes cannot run GPU code paths — your test will not catch GPU-related errors.\n\n" "Before choosing hardware, estimate your VRAM needs (models you run, training data size). Rule of thumb: bf16/fp16 ≈ 2 bytes/param, " "fp32 ≈ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n" "Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). 
" "If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n" "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n" ), "parameters": { "type": "object", "required": [], "additionalProperties": False, "properties": { "hardware": { "type": "string", "enum": [e.value for e in SpaceHardware], "description": "Hardware tier for the sandbox (default: cpu-basic)", }, "private": { "type": "boolean", "description": "If true, create a private Space", }, }, }, } async def sandbox_create_handler( args: dict[str, Any], session: Any = None ) -> tuple[str, bool]: """Handle sandbox_create tool calls.""" # If sandbox already exists, return its info if session and getattr(session, "sandbox", None): sb = session.sandbox return ( f"Sandbox already active: {sb.space_id}\n" f"URL: {sb.url}\n" f"Use bash/read/write/edit to interact with it." ), True hardware = args.get("hardware", "cpu-basic") create_kwargs = {} if "private" in args: create_kwargs["private"] = args["private"] try: sb, error = await _ensure_sandbox(session, hardware=hardware, **create_kwargs) except Exception as e: return f"Failed to create sandbox: {e}", False if error: return error, False return ( f"Sandbox created: {sb.space_id}\n" f"URL: {sb.url}\n" f"Hardware: {hardware}\n" f"Use bash/read/write/edit to interact with it." ), True def _make_tool_handler(sandbox_tool_name: str): """Factory: create a handler for a sandbox operation tool.""" async def handler(args: dict[str, Any], session: Any = None) -> tuple[str, bool]: # Require sandbox to exist — user must approve sandbox_create first if not session or not getattr(session, "sandbox", None): return "No sandbox running. Call sandbox_create first to start one.", False sb = session.sandbox try: result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args) if result.success: output = result.output or "(no output)" return output, True else: error_msg = result.error or "Unknown error" output = result.output if output: return f"{output}\n\nERROR: {error_msg}", False return f"ERROR: {error_msg}", False except Exception as e: return f"Sandbox operation failed: {e}", False return handler def get_sandbox_tools(): """Return all 5 sandbox ToolSpecs (sandbox_create + 4 operation tools).""" from agent.core.tools import ToolSpec tools = [] # sandbox_create (explicit creation, requires approval) tools.append( ToolSpec( name=SANDBOX_CREATE_TOOL_SPEC["name"], description=SANDBOX_CREATE_TOOL_SPEC["description"], parameters=SANDBOX_CREATE_TOOL_SPEC["parameters"], handler=sandbox_create_handler, ) ) # Operation tools (auto-execute, no approval needed) for name in Sandbox.TOOLS.keys(): spec = Sandbox.TOOLS[name] tools.append( ToolSpec( name=name, description=spec["description"], parameters=spec["parameters"], handler=_make_tool_handler(name), ) ) return tools ================================================ FILE: agent/tools/types.py ================================================ """ Types for Hugging Face tools Ported from: hf-mcp-server/packages/mcp/src/types/ """ from typing import TypedDict class ToolResult(TypedDict, total=False): """Result returned by HF tool operations""" formatted: str totalResults: int resultsShared: int isError: bool ================================================ FILE: agent/tools/utilities.py ================================================ """ Utility functions for Hugging Face tools Ported from: hf-mcp-server/packages/mcp/src/jobs/formatters.ts Includes GPU memory validation for job submissions """ import json from datetime 
import datetime from typing import Any, Dict, List, Optional def truncate(text: str, max_length: int) -> str: """Truncate a string to a maximum length with ellipsis""" if len(text) <= max_length: return text return text[: max_length - 3] + "..." def format_date(date_str: Optional[str]) -> str: """Format a date string to a readable format""" if not date_str: return "N/A" try: date = datetime.fromisoformat(date_str.replace("Z", "+00:00")) return date.strftime("%Y-%m-%d %H:%M:%S") except Exception: return date_str def format_command(command: Optional[List[str]]) -> str: """Format command array as a single string""" if not command or len(command) == 0: return "N/A" return " ".join(command) def get_image_or_space(job: Dict[str, Any]) -> str: """Get image/space identifier from job""" if job.get("spaceId"): return job["spaceId"] if job.get("dockerImage"): return job["dockerImage"] return "N/A" def format_jobs_table(jobs: List[Dict[str, Any]]) -> str: """Format jobs as a markdown table""" if len(jobs) == 0: return "No jobs found." # Calculate dynamic ID column width longest_id_length = max(len(job["id"]) for job in jobs) id_column_width = max(longest_id_length, len("JOB ID")) # Define column widths col_widths = { "id": id_column_width, "image": 20, "command": 30, "created": 19, "status": 12, } # Build header header = f"| {'JOB ID'.ljust(col_widths['id'])} | {'IMAGE/SPACE'.ljust(col_widths['image'])} | {'COMMAND'.ljust(col_widths['command'])} | {'CREATED'.ljust(col_widths['created'])} | {'STATUS'.ljust(col_widths['status'])} |" separator = f"|{'-' * (col_widths['id'] + 2)}|{'-' * (col_widths['image'] + 2)}|{'-' * (col_widths['command'] + 2)}|{'-' * (col_widths['created'] + 2)}|{'-' * (col_widths['status'] + 2)}|" # Build rows rows = [] for job in jobs: job_id = job["id"] image = truncate(get_image_or_space(job), col_widths["image"]) command = truncate(format_command(job.get("command")), col_widths["command"]) created = truncate(format_date(job.get("createdAt")), col_widths["created"]) status = truncate(job["status"]["stage"], col_widths["status"]) rows.append( f"| {job_id.ljust(col_widths['id'])} | {image.ljust(col_widths['image'])} | {command.ljust(col_widths['command'])} | {created.ljust(col_widths['created'])} | {status.ljust(col_widths['status'])} |" ) return "\n".join([header, separator] + rows) def format_scheduled_jobs_table(jobs: List[Dict[str, Any]]) -> str: """Format scheduled jobs as a markdown table""" if len(jobs) == 0: return "No scheduled jobs found." 
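# Same markdown-table construction as format_jobs_table above, with extra SCHEDULE, LAST RUN, NEXT RUN, and SUSPENDED columns.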
# Calculate dynamic ID column width longest_id_length = max(len(job["id"]) for job in jobs) id_column_width = max(longest_id_length, len("ID")) # Define column widths col_widths = { "id": id_column_width, "schedule": 12, "image": 18, "command": 25, "lastRun": 19, "nextRun": 19, "suspend": 9, } # Build header header = f"| {'ID'.ljust(col_widths['id'])} | {'SCHEDULE'.ljust(col_widths['schedule'])} | {'IMAGE/SPACE'.ljust(col_widths['image'])} | {'COMMAND'.ljust(col_widths['command'])} | {'LAST RUN'.ljust(col_widths['lastRun'])} | {'NEXT RUN'.ljust(col_widths['nextRun'])} | {'SUSPENDED'.ljust(col_widths['suspend'])} |" separator = f"|{'-' * (col_widths['id'] + 2)}|{'-' * (col_widths['schedule'] + 2)}|{'-' * (col_widths['image'] + 2)}|{'-' * (col_widths['command'] + 2)}|{'-' * (col_widths['lastRun'] + 2)}|{'-' * (col_widths['nextRun'] + 2)}|{'-' * (col_widths['suspend'] + 2)}|" # Build rows rows = [] for job in jobs: job_id = job["id"] schedule = truncate(job["schedule"], col_widths["schedule"]) image = truncate(get_image_or_space(job["jobSpec"]), col_widths["image"]) command = truncate( format_command(job["jobSpec"].get("command")), col_widths["command"] ) last_run = truncate(format_date(job.get("lastRun")), col_widths["lastRun"]) next_run = truncate(format_date(job.get("nextRun")), col_widths["nextRun"]) suspend = "Yes" if job.get("suspend") else "No" rows.append( f"| {job_id.ljust(col_widths['id'])} | {schedule.ljust(col_widths['schedule'])} | {image.ljust(col_widths['image'])} | {command.ljust(col_widths['command'])} | {last_run.ljust(col_widths['lastRun'])} | {next_run.ljust(col_widths['nextRun'])} | {suspend.ljust(col_widths['suspend'])} |" ) return "\n".join([header, separator] + rows) def format_job_details(jobs: Any) -> str: """Format job details as JSON in a markdown code block""" job_array = jobs if isinstance(jobs, list) else [jobs] json_str = json.dumps(job_array, indent=2) return f"```json\n{json_str}\n```" def format_scheduled_job_details(jobs: Any) -> str: """Format scheduled job details as JSON in a markdown code block""" job_array = jobs if isinstance(jobs, list) else [jobs] json_str = json.dumps(job_array, indent=2) return f"```json\n{json_str}\n```" ================================================ FILE: agent/utils/__init__.py ================================================ """ Utility functions and helpers """ ================================================ FILE: agent/utils/boot_timing.py ================================================ """Shared timing and color helpers for startup visual effects.""" import math def settle_curve(progress: float, sharpness: float = 3.0) -> float: """Return noise amount in range 1..0 for normalized progress 0..1.""" t = max(0.0, min(1.0, progress)) return math.exp(-sharpness * t) def warm_gold_from_white(progress: float) -> tuple[int, int, int]: """Interpolate from white to warm gold for progress 0..1.""" t = max(0.0, min(1.0, progress)) return 255, int(255 - 55 * t), int(255 - 175 * t) ================================================ FILE: agent/utils/braille.py ================================================ """Braille-character canvas for high-resolution terminal graphics. Each terminal cell maps to a 2x4 dot grid using Unicode braille characters (U+2800–U+28FF), giving 2× horizontal and 4× vertical resolution. 
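A canvas of W×H terminal cells therefore exposes a (2W)×(4H) pixel grid; e.g. an 80×24 terminal addresses 160×96 dots.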
""" # Braille dot positions: (0,0) (1,0) dots 1,4 # (0,1) (1,1) dots 2,5 # (0,2) (1,2) dots 3,6 # (0,3) (1,3) dots 7,8 _DOT_MAP = ( (0x01, 0x08), (0x02, 0x10), (0x04, 0x20), (0x40, 0x80), ) class BrailleCanvas: """A pixel canvas that renders to braille characters.""" def __init__(self, term_width: int, term_height: int): self.term_width = term_width self.term_height = term_height self.pixel_width = term_width * 2 self.pixel_height = term_height * 4 self._buf = bytearray(term_width * term_height) def clear(self) -> None: for i in range(len(self._buf)): self._buf[i] = 0 def set_pixel(self, x: int, y: int) -> None: if 0 <= x < self.pixel_width and 0 <= y < self.pixel_height: cx, rx = divmod(x, 2) cy, ry = divmod(y, 4) self._buf[cy * self.term_width + cx] |= _DOT_MAP[ry][rx] def render(self) -> list[str]: lines = [] for row in range(self.term_height): offset = row * self.term_width line = "".join( chr(0x2800 + self._buf[offset + col]) for col in range(self.term_width) ) lines.append(line) return lines # ── Bitmap font (5×7 uppercase + digits) ────────────────────────────── _FONT: dict[str, list[str]] = {} def _define_font() -> None: """Define a simple 5×7 bitmap font for uppercase ASCII.""" glyphs = { "A": [" ## ", "# #", "# #", "####", "# #", "# #", "# #"], "B": ["### ", "# #", "# #", "### ", "# #", "# #", "### "], "C": [" ## ", "# #", "# ", "# ", "# ", "# #", " ## "], "D": ["### ", "# #", "# #", "# #", "# #", "# #", "### "], "E": ["####", "# ", "# ", "### ", "# ", "# ", "####"], "F": ["####", "# ", "# ", "### ", "# ", "# ", "# "], "G": [" ## ", "# #", "# ", "# ##", "# #", "# #", " ###"], "H": ["# #", "# #", "# #", "####", "# #", "# #", "# #"], "I": ["###", " # ", " # ", " # ", " # ", " # ", "###"], "J": [" ##", " # ", " # ", " # ", " # ", "# # ", " # "], "K": ["# #", "# # ", "## ", "## ", "# # ", "# #", "# #"], "L": ["# ", "# ", "# ", "# ", "# ", "# ", "####"], "M": ["# #", "## ##", "# # #", "# # #", "# #", "# #", "# #"], "N": ["# #", "## #", "## #", "# ##", "# ##", "# #", "# #"], "O": [" ## ", "# #", "# #", "# #", "# #", "# #", " ## "], "P": ["### ", "# #", "# #", "### ", "# ", "# ", "# "], "Q": [" ## ", "# #", "# #", "# #", "# ##", "# #", " ## "], "R": ["### ", "# #", "# #", "### ", "# # ", "# #", "# #"], "S": [" ## ", "# #", "# ", " ## ", " #", "# #", " ## "], "T": ["#####", " # ", " # ", " # ", " # ", " # ", " # "], "U": ["# #", "# #", "# #", "# #", "# #", "# #", " ## "], "V": ["# #", "# #", "# #", " # # ", " # # ", " # ", " # "], "W": ["# #", "# #", "# #", "# # #", "# # #", "## ##", "# #"], "X": ["# #", "# #", " ## ", " ## ", " ## ", "# #", "# #"], "Y": ["# #", "# #", " # # ", " # ", " # ", " # ", " # "], "Z": ["####", " #", " # ", " # ", "# ", "# ", "####"], " ": [" ", " ", " ", " ", " ", " ", " "], "0": [" ## ", "# #", "# #", "# #", "# #", "# #", " ## "], "1": [" # ", "## ", " # ", " # ", " # ", " # ", "###"], "2": [" ## ", "# #", " #", " # ", " # ", "# ", "####"], "3": [" ## ", "# #", " #", " ## ", " #", "# #", " ## "], "4": ["# #", "# #", "# #", "####", " #", " #", " #"], "5": ["####", "# ", "### ", " #", " #", "# #", " ## "], "6": [" ## ", "# ", "### ", "# #", "# #", "# #", " ## "], "7": ["####", " #", " # ", " # ", " # ", " # ", " # "], "8": [" ## ", "# #", "# #", " ## ", "# #", "# #", " ## "], "9": [" ## ", "# #", "# #", " ###", " #", " #", " ## "], } _FONT.update(glyphs) _define_font() def text_to_pixels(text: str, scale: int = 1) -> list[tuple[int, int]]: """Convert text string to a list of (x, y) pixel positions using bitmap font.""" pixels = [] cursor_x = 0 for ch in 
text.upper(): glyph = _FONT.get(ch) if glyph is None: cursor_x += 4 * scale continue for row_idx, row in enumerate(glyph): for col_idx, cell in enumerate(row): if cell == "#": for sy in range(scale): for sx in range(scale): pixels.append((cursor_x + col_idx * scale + sx, row_idx * scale + sy)) glyph_width = max(len(r) for r in glyph) cursor_x += (glyph_width + 1) * scale return pixels ================================================ FILE: agent/utils/crt_boot.py ================================================ """CRT / glitch boot sequence effect for CLI startup. Simulates an old CRT terminal booting up: text appearing character by character with noise artifacts, then settling into a clean display. """ import random import time from rich.console import Console from rich.text import Text from rich.live import Live from agent.utils.boot_timing import settle_curve def _glitch_text(text: str, intensity: float, rng: random.Random) -> str: """Add random glitch characters to text.""" glitch_chars = "█▓▒░┃┫┣╋╏╎─━┅┄" result = list(text) for i in range(len(result)): if rng.random() < intensity: result[i] = rng.choice(glitch_chars) return "".join(result) def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> None: """Run the CRT boot sequence effect. Args: console: Rich console instance. boot_lines: List of (text, rich_style) tuples to display. """ term_height = min(console.height - 2, 40) rng = random.Random(42) with Live(console=console, refresh_per_second=30, transient=True) as live: displayed_lines: list[tuple[str, str]] = [] for line_text, line_style in boot_lines: if not line_text: displayed_lines.append(("", "")) continue line_len = max(1, len(line_text)) # Type out each character for char_idx in range(len(line_text) + 1): result = Text() progress = char_idx / line_len noise = settle_curve(progress) prev_glitch_chance = 0.01 + 0.06 * noise prev_glitch_intensity = 0.02 + 0.12 * noise scanline_chance = 0.005 + 0.03 * noise # Render previously completed lines for prev_text, prev_style in displayed_lines: if rng.random() < prev_glitch_chance: result.append(_glitch_text(prev_text, prev_glitch_intensity, rng), style=prev_style) else: result.append(prev_text, style=prev_style) result.append("\n") # Current line being typed typed = line_text[:char_idx] cursor = "█" if char_idx < len(line_text) else "" # Noise after cursor noise_tail = "" if char_idx < len(line_text): noise_len = rng.randint(0, int(1 + 5 * noise)) noise_tail = "".join(rng.choice("░▒▓") for _ in range(noise_len)) result.append(typed, style=line_style) result.append(cursor, style="bold rgb(255,200,80)") result.append(noise_tail, style="dim rgb(180,140,40)") result.append("\n") # Faint scanlines in remaining space remaining = term_height - len(displayed_lines) - 2 for _ in range(max(0, remaining)): if rng.random() < scanline_chance: scan_len = rng.randint(5, 30) result.append("─" * scan_len, style="dim rgb(180,140,40)") result.append("\n") live.update(result) # Variable typing speed if line_text[char_idx - 1:char_idx] in " .": time.sleep(0.025) else: time.sleep(0.010) displayed_lines.append((line_text, line_style)) time.sleep(0.06) # Hold with blinking cursor for frame in range(20): result = Text() for prev_text, prev_style in displayed_lines: result.append(prev_text, style=prev_style) result.append("\n") if frame % 8 < 4: result.append("█", style="rgb(255,200,80)") live.update(result) time.sleep(0.05) # Print final clean frame final = Text() for prev_text, prev_style in displayed_lines: final.append(prev_text, 
style=prev_style) final.append("\n") console.print(final) ================================================ FILE: agent/utils/particle_logo.py ================================================ """Particle coalesce effect for the HUGGING FACE ML INTERN logo. Random particles swirl in from the edges, converge to form the text "HUGGING FACE / ML INTERN", hold briefly, then the final frame is printed. Rendered with braille characters for high detail. Based on Leandro's particle_coalesce.py demo. """ import math import random import time from rich.console import Console from rich.text import Text from rich.align import Align from rich.live import Live from agent.utils.braille import BrailleCanvas, text_to_pixels from agent.utils.boot_timing import settle_curve, warm_gold_from_white class Particle: __slots__ = ("x", "y", "target_x", "target_y", "vx", "vy", "phase", "delay") def __init__(self, x: float, y: float, target_x: float, target_y: float, delay: float = 0): self.x = x self.y = y self.target_x = target_x self.target_y = target_y self.vx = 0.0 self.vy = 0.0 self.phase = random.uniform(0, math.pi * 2) self.delay = delay def update_converge(self, t: float, strength: float = 0.08, damping: float = 0.92): """Move toward target with spring-like physics.""" if t < self.delay: # Still in swirl phase self.x += self.vx self.y += self.vy self.vx *= 0.99 self.vy *= 0.99 # Gentle spiral angle = self.phase + t * 2 self.vx += math.cos(angle) * 0.3 self.vy += math.sin(angle) * 0.3 return # Spring toward target dx = self.target_x - self.x dy = self.target_y - self.y self.vx += dx * strength self.vy += dy * strength self.vx *= damping self.vy *= damping self.x += self.vx self.y += self.vy @property def at_target(self) -> bool: return abs(self.x - self.target_x) < 1.5 and abs(self.y - self.target_y) < 1.5 def run_particle_logo(console: Console, hold_seconds: float = 1.5) -> None: """Run the particle coalesce effect.""" term_width = min(console.width, 120) term_height = min(console.height - 4, 35) canvas = BrailleCanvas(term_width, term_height) # Get target positions from text text_pixels_line1 = text_to_pixels("HUGGING FACE", scale=2) text_pixels_line2 = text_to_pixels("ML INTERN", scale=2) # Calculate dimensions for centering def get_bounds(pixels): if not pixels: return 0, 0, 0, 0 xs = [p[0] for p in pixels] ys = [p[1] for p in pixels] return min(xs), max(xs), min(ys), max(ys) min_x1, max_x1, min_y1, max_y1 = get_bounds(text_pixels_line1) min_x2, max_x2, min_y2, max_y2 = get_bounds(text_pixels_line2) w1, h1 = max_x1 - min_x1 + 1, max_y1 - min_y1 + 1 w2, h2 = max_x2 - min_x2 + 1, max_y2 - min_y2 + 1 total_h = h1 + 6 + h2 # gap between lines start_y = (canvas.pixel_height - total_h) // 2 # Center line 1 offset_x1 = (canvas.pixel_width - w1) // 2 - min_x1 offset_y1 = start_y - min_y1 targets_1 = [(p[0] + offset_x1, p[1] + offset_y1) for p in text_pixels_line1] # Center line 2 offset_x2 = (canvas.pixel_width - w2) // 2 - min_x2 offset_y2 = start_y + h1 + 6 - min_y2 targets_2 = [(p[0] + offset_x2, p[1] + offset_y2) for p in text_pixels_line2] all_targets = targets_1 + targets_2 # Subsample for performance — take every Nth pixel step = max(1, len(all_targets) // 1500) sampled_targets = all_targets[::step] # Create particles at random edge positions rng = random.Random(42) particles = [] pw, ph = canvas.pixel_width, canvas.pixel_height for i, (tx, ty) in enumerate(sampled_targets): # Spawn from random edge side = rng.choice(["top", "bottom", "left", "right"]) if side == "top": sx, sy = rng.uniform(0, pw), 
rng.uniform(-20, -5) elif side == "bottom": sx, sy = rng.uniform(0, pw), rng.uniform(ph + 5, ph + 20) elif side == "left": sx, sy = rng.uniform(-20, -5), rng.uniform(0, ph) else: sx, sy = rng.uniform(pw + 5, pw + 20), rng.uniform(0, ph) delay = rng.uniform(0, 0.4) # staggered start p = Particle(sx, sy, tx, ty, delay=delay) # Initial velocity — gentle swirl angle = math.atan2(ph / 2 - sy, pw / 2 - sx) + rng.gauss(0, 0.8) speed = rng.uniform(1.0, 2.5) p.vx = math.cos(angle) * speed p.vy = math.sin(angle) * speed particles.append(p) # Also add some extra ambient particles that never converge ambient = [] for _ in range(200): ax = rng.uniform(0, pw) ay = rng.uniform(0, ph) ap = Particle(ax, ay, ax, ay) ap.vx = rng.gauss(0, 1) ap.vy = rng.gauss(0, 1) ambient.append(ap) # Timing: 1s converge + 2s hold = 3s total fps = 24 converge_frames = int(fps * 0.9) hold_frames = int(fps * hold_seconds) total_frames = converge_frames + hold_frames with Live(console=console, refresh_per_second=fps, transient=True) as live: for frame in range(total_frames): canvas.clear() t = frame * 0.03 # Update ambient particles (always drifting) for ap in ambient: ap.x += ap.vx + math.sin(t + ap.phase) * 0.5 ap.y += ap.vy + math.cos(t + ap.phase * 1.3) * 0.5 # Wrap around ap.x = ap.x % pw ap.y = ap.y % ph # Fade out ambient during hold phase if frame < converge_frames: alpha = 0.3 + 0.2 * math.sin(t * 2 + ap.phase) else: fade = (frame - converge_frames) / hold_frames alpha = (0.3 + 0.2 * math.sin(t * 2 + ap.phase)) * (1 - fade) if alpha > 0.25: canvas.set_pixel(int(ap.x), int(ap.y)) if frame < converge_frames: # Converge phase progress = frame / converge_frames noise = settle_curve(progress) for p in particles: p.update_converge(t, strength=0.06, damping=0.90) canvas.set_pixel(int(p.x), int(p.y)) # Trail effect trail_scale = 0.2 + 0.5 * noise trail_x = int(p.x - p.vx * trail_scale) trail_y = int(p.y - p.vy * trail_scale) canvas.set_pixel(trail_x, trail_y) # Color transitions from white to warm gold r, g, b = warm_gold_from_white(progress) else: # Hold phase — settle into solid logo settle_t = (frame - converge_frames) / hold_frames for p in particles: # Jitter decays to zero jitter = (1 - settle_t) * 0.7 jx = p.target_x + math.sin(t * 3 + p.phase) * jitter jy = p.target_y + math.cos(t * 3 + p.phase * 1.5) * jitter canvas.set_pixel(int(jx), int(jy)) canvas.set_pixel(int(p.target_x), int(p.target_y)) r, g, b = 255, 200, 80 # Render with color lines = canvas.render() result = Text() for line in lines: for ch in line: if ch == chr(0x2800): result.append(ch) else: result.append(ch, style=f"rgb({r},{g},{b})") result.append("\n") live.update(Align.center(result)) time.sleep(1.0 / fps) # Print final settled frame canvas.clear() for p in particles: canvas.set_pixel(int(p.target_x), int(p.target_y)) final = Text() for line in canvas.render(): for ch in line: if ch == chr(0x2800): final.append(ch) else: final.append(ch, style="rgb(255,200,80)") final.append("\n") console.print(Align.center(final)) ================================================ FILE: agent/utils/reliability_checks.py ================================================ """Reliability checks for job submissions and other operations""" def check_training_script_save_pattern(script: str) -> str | None: """Check if a training script properly saves models.""" has_from_pretrained = "from_pretrained" in script has_push_to_hub = "push_to_hub" in script if has_from_pretrained and not has_push_to_hub: return "\n\033[91mWARNING: No model save detected in this script. 
Ensure this is intentional.\033[0m" elif has_from_pretrained and has_push_to_hub: return "\n\033[92mModel will be pushed to hub after training.\033[0m" return None ================================================ FILE: agent/utils/terminal_display.py ================================================ """ Terminal display utilities — rich-powered CLI formatting. """ import re from rich.console import Console from rich.markdown import Heading, Markdown from rich.panel import Panel from rich.theme import Theme class _LeftHeading(Heading): """Rich's default Markdown renders h1/h2 centered via Align.center. Yield the styled text directly so headings stay left-aligned.""" def __rich_console__(self, console, options): self.text.justify = "left" yield self.text Markdown.elements["heading_open"] = _LeftHeading _ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]") def _clip_to_width(s: str, width: int) -> str: """Truncate a string to `width` visible columns, preserving ANSI styles. Needed for the sub-agent live redraw: cursor-up-and-erase assumes one logical line == one terminal row. If a line wraps, cursor-up undershoots and the next redraw corrupts the display. Truncating prevents wrap. """ if width <= 0: return s out: list[str] = [] visible = 0 i = 0 # Reserve 1 char for the trailing ellipsis limit = width - 1 truncated = False while i < len(s): m = _ANSI_RE.match(s, i) if m: out.append(m.group()) i = m.end() continue if visible >= limit: truncated = True break out.append(s[i]) visible += 1 i += 1 if truncated: # Strip styles (so ellipsis isn't left hanging inside a style run) out.append("\033[0m…") return "".join(out) _THEME = Theme({ "tool.name": "bold rgb(255,200,80)", "tool.args": "dim", "tool.ok": "dim green", "tool.fail": "dim red", "info": "dim", "muted": "dim", # Markdown emphasis colors "markdown.strong": "bold rgb(255,200,80)", "markdown.emphasis": "italic rgb(180,140,40)", "markdown.code": "rgb(120,220,255)", "markdown.code_block": "rgb(120,220,255)", "markdown.link": "underline rgb(90,180,255)", "markdown.h1": "bold rgb(255,200,80)", "markdown.h2": "bold rgb(240,180,95)", "markdown.h3": "bold rgb(220,165,100)", }) _console = Console(theme=_THEME, highlight=False) # Indent prefix for all agent output (aligns under the `>` prompt) _I = " " def get_console() -> Console: return _console # ── Banner ───────────────────────────────────────────────────────────── def print_banner(model: str | None = None, hf_user: str | None = None) -> None: """Print particle logo then CRT boot sequence with system info.""" from agent.utils.particle_logo import run_particle_logo from agent.utils.crt_boot import run_boot_sequence # Particle coalesce logo — 1.5s converge, 2s hold run_particle_logo(_console, hold_seconds=2.0) # Clear screen for CRT boot — starts from top _console.file.write("\033[2J\033[H") _console.file.flush() model_label = model or "bedrock/us.anthropic.claude-opus-4-6-v1" user_label = hf_user or "not logged in" # Warm gold palette matching the shimmer highlight (255, 200, 80) gold = "rgb(255,200,80)" dim_gold = "rgb(180,140,40)" boot_lines = [ (f"{_I}Initializing agent runtime...", gold), (f"{_I} User: {user_label}", dim_gold), (f"{_I} Model: {model_label}", dim_gold), (f"{_I} Tools: loading...", dim_gold), ("", ""), (f"{_I}/help for commands · /model to switch · /quit to exit", gold), ] run_boot_sequence(_console, boot_lines) # ── Init progress ────────────────────────────────────────────────────── def print_init_done(tool_count: int = 0) -> None: import time f = _console.file # Overwrite the 
"Tools: loading..." line with actual count f.write(f"\033[A\033[A\033[A\033[K") # Move up 3 lines (blank + help + blank) then up to tools line f.write(f"\033[A\033[K") gold = "\033[38;2;180;140;40m" reset = "\033[0m" tool_text = f"{_I} Tools: {tool_count} loaded" for ch in tool_text: f.write(f"{gold}{ch}{reset}") f.flush() time.sleep(0.012) f.write("\n\n") # Reprint the help line f.write(f"{_I}\033[38;2;255;200;80m/help for commands · /model to switch · /quit to exit{reset}\n\n") # Ready message — minimal padding f.write(f"{_I}\033[38;2;255;200;80mReady. Let's build something impressive.{reset}\n") f.flush() # ── Tool calls ───────────────────────────────────────────────────────── def print_tool_call(tool_name: str, args_preview: str) -> None: import time f = _console.file # CRT-style: type out tool name in HF yellow gold = "\033[38;2;255;200;80m" reset = "\033[0m" f.write(f"{_I}{gold}▸ ") for ch in tool_name: f.write(ch) f.flush() time.sleep(0.015) f.write(f"{reset} \033[2m{args_preview}{reset}\n") f.flush() def print_tool_output(output: str, success: bool, truncate: bool = True) -> None: if truncate: output = _truncate(output, max_lines=10) style = "tool.ok" if success else "tool.fail" # Indent each line of tool output indented = "\n".join(f"{_I} {line}" for line in output.split("\n")) _console.print(f"[{style}]{indented}[/{style}]") class SubAgentDisplayManager: """Manages multiple concurrent sub-agent displays. Each agent gets its own stats and rolling tool-call log. All agents are rendered together so terminal escape-code erase/redraw stays consistent. """ _MAX_VISIBLE = 4 # tool-call lines shown per agent def __init__(self): self._agents: dict[str, dict] = {} # agent_id -> state dict self._lines_on_screen = 0 self._ticker_task = None def start(self, agent_id: str, label: str = "research") -> None: import asyncio import time self._agents[agent_id] = { "label": label, "calls": [], "tool_count": 0, "token_count": 0, "start_time": time.monotonic(), } if not self._ticker_task: self._ticker_task = asyncio.ensure_future(self._tick()) self._redraw() def set_tokens(self, agent_id: str, tokens: int) -> None: if agent_id in self._agents: self._agents[agent_id]["token_count"] = tokens def set_tool_count(self, agent_id: str, count: int) -> None: if agent_id in self._agents: self._agents[agent_id]["tool_count"] = count def add_call(self, agent_id: str, tool_desc: str) -> None: if agent_id in self._agents: self._agents[agent_id]["calls"].append(tool_desc) self._redraw() def clear(self, agent_id: str) -> None: # On completion: erase the live region, freeze a single-line summary # for this agent ("✓ research: … (stats)") above the live region so # the user sees each sub-agent finish cleanly without the tool-call # noise, then redraw remaining live agents. 
agent = self._agents.pop(agent_id, None) self._erase() if agent is not None: width = max(10, _console.width) line = _clip_to_width(self._render_completion_line(agent), width) _console.file.write(line + "\n") _console.file.flush() self._lines_on_screen = 0 if not self._agents: if self._ticker_task: self._ticker_task.cancel() self._ticker_task = None else: self._redraw() @staticmethod def _render_completion_line(agent: dict) -> str: stats = SubAgentDisplayManager._format_stats(agent) label = agent["label"] # dim green check + dim label; stats in parens line = f"{_I}\033[38;2;120;200;140m✓\033[0m \033[2m{label}\033[0m" if stats: line += f" \033[2m({stats})\033[0m" return line async def _tick(self) -> None: import asyncio try: while True: await asyncio.sleep(1.0) if self._agents: self._redraw() except asyncio.CancelledError: pass @staticmethod def _format_stats(agent: dict) -> str: import time start = agent["start_time"] if start is None: return "" elapsed = time.monotonic() - start if elapsed < 60: time_str = f"{elapsed:.0f}s" else: time_str = f"{elapsed / 60:.0f}m {elapsed % 60:.0f}s" tok = agent["token_count"] tok_str = f"{tok / 1000:.1f}k" if tok >= 1000 else str(tok) return f"{agent['tool_count']} tool uses · {tok_str} tokens · {time_str}" def _erase(self) -> None: if self._lines_on_screen > 0: f = _console.file for _ in range(self._lines_on_screen): f.write("\033[A\033[K") f.flush() def _render_agent_lines(self, agent: dict, compact: bool = False) -> list[str]: """Render one agent's block. compact=True → single line (label + stats + most-recent tool name); compact=False → header + up to _MAX_VISIBLE rolling tool-call lines. We use compact mode when multiple agents are live so the total live region stays small enough to fit on one screen. Otherwise cursor-up can't reach lines that have scrolled into scrollback, and every redraw pollutes history with a stale copy. 
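Illustrative compact line (ANSI styling omitted): '▸ research (3 tool uses · 1.2k tokens · 42s) · bash'.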
""" stats = self._format_stats(agent) label = agent["label"] header = f"{_I}\033[38;2;255;200;80m▸ {label}\033[0m" if stats: header += f" \033[2m({stats})\033[0m" if compact: latest = agent["calls"][-1] if agent["calls"] else "" if latest: # Strip long json tails for the inline view short = latest.split(" ")[0] if " " in latest else latest header += f" \033[2m·\033[0m \033[2m{short}\033[0m" return [header] lines = [header] visible = agent["calls"][-self._MAX_VISIBLE:] for desc in visible: lines.append(f"{_I} \033[2m{desc}\033[0m") return lines def _redraw(self) -> None: f = _console.file self._erase() compact = len(self._agents) > 1 width = max(10, _console.width) lines: list[str] = [] for agent in self._agents.values(): for ln in self._render_agent_lines(agent, compact=compact): lines.append(_clip_to_width(ln, width)) for line in lines: f.write(line + "\n") f.flush() self._lines_on_screen = len(lines) _subagent_display = SubAgentDisplayManager() def print_tool_log(tool: str, log: str, agent_id: str = "", label: str = "") -> None: """Handle tool log events — sub-agent calls get the rolling display.""" if tool == "research": aid = agent_id or "research" if log == "Starting research sub-agent...": _subagent_display.start(aid, label or "research") elif log == "Research complete.": _subagent_display.clear(aid) elif log.startswith("tokens:"): _subagent_display.set_tokens(aid, int(log[7:])) elif log.startswith("tools:"): _subagent_display.set_tool_count(aid, int(log[6:])) else: _subagent_display.add_call(aid, log) else: _console.print(f"{_I}[dim]{tool}: {log}[/dim]") # ── Messages ─────────────────────────────────────────────────────────── async def print_markdown( text: str, cancel_event: "asyncio.Event | None" = None, instant: bool = False, ) -> None: import asyncio import io, random from rich.padding import Padding _console.print() # Render markdown to a string buffer so we can type it out buf = io.StringIO() # Important: StringIO is not a TTY, so Rich would normally strip styles. # Force terminal rendering so ANSI style codes are preserved for typewriter output. buf_console = Console( file=buf, width=_console.width, highlight=False, theme=_THEME, force_terminal=True, color_system=_console.color_system or "truecolor", ) buf_console.print(Padding(Markdown(text), (0, 0, 0, 2))) rendered = buf.getvalue() # Strip trailing whitespace from each line so we don't type across the full width lines = rendered.split("\n") rendered = "\n".join(line.rstrip() for line in lines) f = _console.file # Headless / non-interactive: dump the rendered markdown in one write. if instant: f.write(rendered) f.write("\n") f.flush() return # CRT typewriter effect — async so the event loop can service signal # handlers (Ctrl+C during streaming) between characters. If cancelled # mid-type, stop cleanly: write an ANSI reset so half-open color state # doesn't bleed onto the "interrupted" line, and return. 
rng = random.Random(42) cancelled = False for ch in rendered: if cancel_event is not None and cancel_event.is_set(): cancelled = True break f.write(ch) f.flush() if ch == "\n": await asyncio.sleep(0.002) elif ch == " ": await asyncio.sleep(0.002) elif rng.random() < 0.03: await asyncio.sleep(0.015) else: await asyncio.sleep(0.004) f.write("\033[0m\n" if cancelled else "\n") f.flush() def print_error(message: str) -> None: _console.print(f"\n{_I}[bold red]Error:[/bold red] {message}") def print_turn_complete() -> None: pass # no separator — clean output def print_interrupted() -> None: _console.print(f"\n{_I}[dim italic]interrupted[/dim italic]") def print_compacted(old_tokens: int, new_tokens: int) -> None: _console.print(f"{_I}[dim]context compacted: {old_tokens:,} → {new_tokens:,} tokens[/dim]") # ── Approval ─────────────────────────────────────────────────────────── def print_approval_header(count: int) -> None: label = f"Approval required — {count} item{'s' if count != 1 else ''}" _console.print() _console.print(f"{_I}", Panel(f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False)) def print_approval_item(index: int, total: int, tool_name: str, operation: str) -> None: _console.print(f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}") def print_yolo_approve(count: int) -> None: _console.print(f"{_I}[bold yellow]yolo →[/bold yellow] auto-approved {count} item(s)") # ── Help ─────────────────────────────────────────────────────────────── HELP_TEXT = f"""\ {_I}[bold]Commands[/bold] {_I} [cyan]/help[/cyan] Show this help {_I} [cyan]/undo[/cyan] Undo last turn {_I} [cyan]/compact[/cyan] Compact context window {_I} [cyan]/model[/cyan] [id] Show available models or switch {_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off) {_I} [cyan]/yolo[/cyan] Toggle auto-approve mode {_I} [cyan]/status[/cyan] Current model & turn count {_I} [cyan]/quit[/cyan] Exit""" def print_help() -> None: _console.print() _console.print(HELP_TEXT) _console.print() # ── Plan display ─────────────────────────────────────────────────────── def format_plan_display() -> str: """Format the current plan for display.""" from agent.tools.plan_tool import get_current_plan plan = get_current_plan() if not plan: return "" completed = [t for t in plan if t["status"] == "completed"] in_progress = [t for t in plan if t["status"] == "in_progress"] pending = [t for t in plan if t["status"] == "pending"] lines = [] for t in completed: lines.append(f"{_I}[green]✓[/green] [dim]{t['content']}[/dim]") for t in in_progress: lines.append(f"{_I}[yellow]▸[/yellow] {t['content']}") for t in pending: lines.append(f"{_I}[dim]○ {t['content']}[/dim]") summary = f"[dim]{len(completed)}/{len(plan)} done[/dim]" lines.append(f"{_I}{summary}") return "\n".join(lines) def print_plan() -> None: plan_str = format_plan_display() if plan_str: _console.print(plan_str) # ── Formatting for plan_tool output (used by plan_tool handler) ──────── def format_plan_tool_output(todos: list) -> str: if not todos: return "Plan is empty." lines = ["Plan updated:", ""] completed = [t for t in todos if t["status"] == "completed"] in_progress = [t for t in todos if t["status"] == "in_progress"] pending = [t for t in todos if t["status"] == "pending"] for t in completed: lines.append(f" [x] {t['id']}. {t['content']}") for t in in_progress: lines.append(f" [~] {t['id']}. {t['content']}") for t in pending: lines.append(f" [ ] {t['id']}. 
{t['content']}") lines.append(f"\n{len(completed)}/{len(todos)} done") return "\n".join(lines) # ── Internal helpers ─────────────────────────────────────────────────── def _truncate(text: str, max_lines: int = 6) -> str: lines = text.split("\n") if len(lines) <= max_lines: return text return "\n".join(lines[:max_lines]) + f"\n... ({len(lines) - max_lines} more lines)" ================================================ FILE: backend/__init__.py ================================================ # Backend package for HF Agent web interface ================================================ FILE: backend/dependencies.py ================================================ """Authentication dependencies for FastAPI routes. - In dev mode (OAUTH_CLIENT_ID not set): auth is bypassed, returns a default "dev" user. - In production: validates Bearer tokens or cookies against HF OAuth. """ import logging import os import time from typing import Any import httpx from fastapi import HTTPException, Request, status logger = logging.getLogger(__name__) OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", "")) HF_EMPLOYEE_ORG = os.environ.get("HF_EMPLOYEE_ORG", "huggingface") # Simple in-memory token cache: token -> (user_info, expiry_time) _token_cache: dict[str, tuple[dict[str, Any], float]] = {} TOKEN_CACHE_TTL = 300 # 5 minutes # Org membership cache: key -> expiry_time (only caches positive results) _org_member_cache: dict[str, float] = {} DEV_USER: dict[str, Any] = { "user_id": "dev", "username": "dev", "authenticated": True, "plan": "org", # Dev runs at the Pro/Org quota tier so local testing isn't capped. } # Plan field discovery — log the whoami-v2 shape once at DEBUG so we can # confirm the actual key in production without hammering the HF API. _WHOAMI_SHAPE_LOGGED = False async def _validate_token(token: str) -> dict[str, Any] | None: """Validate a token against HF OAuth userinfo endpoint. Results are cached for TOKEN_CACHE_TTL seconds to avoid excessive API calls. """ now = time.time() # Check cache if token in _token_cache: user_info, expiry = _token_cache[token] if now < expiry: return user_info del _token_cache[token] # Validate against HF async with httpx.AsyncClient(timeout=10.0) as client: try: response = await client.get( f"{OPENID_PROVIDER_URL}/oauth/userinfo", headers={"Authorization": f"Bearer {token}"}, ) if response.status_code != 200: logger.debug("Token validation failed: status %d", response.status_code) return None user_info = response.json() _token_cache[token] = (user_info, now + TOKEN_CACHE_TTL) return user_info except httpx.HTTPError as e: logger.warning("Token validation error: %s", e) return None def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]: """Build a normalized user dict from HF userinfo response.""" return { "user_id": user_info.get("sub", user_info.get("preferred_username", "unknown")), "username": user_info.get("preferred_username", "unknown"), "name": user_info.get("name"), "picture": user_info.get("picture"), "authenticated": True, } def _normalize_plan(whoami: dict[str, Any]) -> str: """Map an HF /api/whoami-v2 payload to one of: 'free' | 'pro' | 'org'. The exact field shape in whoami-v2 isn't documented for our purposes, so we try a handful of likely keys and fall back to 'free'. The first call logs the raw shape at DEBUG (see `_fetch_user_plan`) so we can pin the real key post-deploy. 
""" plan_str = "" for key in ("plan", "type", "accountType"): val = whoami.get(key) if isinstance(val, str) and val: plan_str = val.lower() break if not plan_str: if whoami.get("isPro") is True or whoami.get("is_pro") is True: return "pro" if "pro" in plan_str or "enterprise" in plan_str or "team" in plan_str: return "pro" # Org tier: anyone in a paid / enterprise org. We don't pay for this # right now, but the "pro" cap applies identically. orgs = whoami.get("orgs") or [] if isinstance(orgs, list): for org in orgs: if isinstance(org, dict): org_plan = str(org.get("plan") or org.get("type") or "").lower() if "pro" in org_plan or "enterprise" in org_plan or "team" in org_plan: return "org" return "free" async def _fetch_user_plan(token: str) -> str: """Look up the user's HF plan via /api/whoami-v2. Returns 'free' | 'pro' | 'org'. Non-200, network errors, or an unknown payload shape all collapse to 'free' — safe default; we'd rather under- grant the Pro cap than over-grant it on bad data. """ global _WHOAMI_SHAPE_LOGGED async with httpx.AsyncClient(timeout=5.0) as client: try: resp = await client.get( f"{OPENID_PROVIDER_URL}/api/whoami-v2", headers={"Authorization": f"Bearer {token}"}, ) if resp.status_code != 200: return "free" whoami = resp.json() except httpx.HTTPError: return "free" except ValueError: return "free" if not _WHOAMI_SHAPE_LOGGED: _WHOAMI_SHAPE_LOGGED = True logger.debug( "whoami-v2 payload keys: %s (sample values: plan=%r type=%r isPro=%r)", sorted(whoami.keys()) if isinstance(whoami, dict) else type(whoami).__name__, whoami.get("plan") if isinstance(whoami, dict) else None, whoami.get("type") if isinstance(whoami, dict) else None, whoami.get("isPro") if isinstance(whoami, dict) else None, ) if not isinstance(whoami, dict): return "free" return _normalize_plan(whoami) async def _extract_user_from_token(token: str) -> dict[str, Any] | None: """Validate a token and return a user dict, or None.""" user_info = await _validate_token(token) if user_info is None: return None user = _user_from_info(user_info) user["plan"] = await _fetch_user_plan(token) return user async def check_org_membership(token: str, org_name: str) -> bool: """Check if the token owner belongs to an HF org. Only caches positive results.""" now = time.time() key = token + org_name cached = _org_member_cache.get(key) if cached and cached > now: return True async with httpx.AsyncClient(timeout=10.0) as client: try: resp = await client.get( f"{OPENID_PROVIDER_URL}/api/whoami-v2", headers={"Authorization": f"Bearer {token}"}, ) if resp.status_code != 200: return False orgs = {o.get("name") for o in resp.json().get("orgs", [])} if org_name in orgs: _org_member_cache[key] = now + TOKEN_CACHE_TTL return True return False except httpx.HTTPError: return False async def get_current_user(request: Request) -> dict[str, Any]: """FastAPI dependency: extract and validate the current user. Checks (in order): 1. Authorization: Bearer header 2. hf_access_token cookie In dev mode (AUTH_ENABLED=False), returns a default dev user. """ if not AUTH_ENABLED: return DEV_USER # Try Authorization header auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): token = auth_header[7:] user = await _extract_user_from_token(token) if user: return user # Try cookie token = request.cookies.get("hf_access_token") if token: user = await _extract_user_from_token(token) if user: return user raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated. 
Please log in via /auth/login.", headers={"WWW-Authenticate": "Bearer"}, ) def _extract_token(request: Request) -> str | None: """Pull the HF access token from the Authorization header or cookie. Mirrors the lookup order used by ``get_current_user``. """ auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): return auth_header[7:] return request.cookies.get("hf_access_token") async def require_huggingface_org_member(request: Request) -> bool: """Return True if the caller is a member of the ``huggingface`` org. Used to gate endpoints that can push a session onto an Anthropic model billed to the Space's ``ANTHROPIC_API_KEY``. Returns True unconditionally in dev mode so local testing isn't blocked. """ if not AUTH_ENABLED: return True token = _extract_token(request) if not token: return False return await check_org_membership(token, HF_EMPLOYEE_ORG) ================================================ FILE: backend/main.py ================================================ """FastAPI application for HF Agent web interface.""" import logging import os from contextlib import asynccontextmanager from pathlib import Path from dotenv import load_dotenv from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from routes.agent import router as agent_router from routes.auth import router as auth_router # Load .env from project root (parent directory) load_dotenv(Path(__file__).parent.parent / ".env") # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan handler.""" logger.info("Starting HF Agent backend...") yield logger.info("Shutting down HF Agent backend...") app = FastAPI( title="HF Agent", description="ML Engineering Assistant API", version="1.0.0", lifespan=lifespan, ) # CORS middleware for development app.add_middleware( CORSMiddleware, allow_origins=[ "http://localhost:5173", # Vite dev server "http://localhost:3000", "http://127.0.0.1:5173", "http://127.0.0.1:3000", ], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Include routers app.include_router(agent_router) app.include_router(auth_router) # Serve static files (frontend build) in production static_path = Path(__file__).parent.parent / "static" if static_path.exists(): app.mount("/", StaticFiles(directory=str(static_path), html=True), name="static") logger.info(f"Serving static files from {static_path}") else: logger.info("No static directory found, running in API-only mode") @app.get("/api") async def api_root(): """API root endpoint.""" return { "name": "HF Agent API", "version": "1.0.0", "docs": "/docs", } if __name__ == "__main__": import uvicorn port = int(os.environ.get("PORT", 7860)) uvicorn.run(app, host="0.0.0.0", port=port) ================================================ FILE: backend/models.py ================================================ """Pydantic models for API requests and responses.""" from enum import Enum from typing import Any from pydantic import BaseModel class OpType(str, Enum): """Operation types matching agent/core/agent_loop.py.""" USER_INPUT = "user_input" EXEC_APPROVAL = "exec_approval" INTERRUPT = "interrupt" UNDO = "undo" COMPACT = "compact" SHUTDOWN = "shutdown" class Operation(BaseModel): """Operation to be submitted to the agent.""" op_type: OpType data: dict[str, Any] | None = None class 
Submission(BaseModel): """Submission wrapper with ID and operation.""" id: str operation: Operation class ToolApproval(BaseModel): """Approval decision for a single tool call.""" tool_call_id: str approved: bool feedback: str | None = None edited_script: str | None = None class ApprovalRequest(BaseModel): """Request to approve/reject tool calls.""" session_id: str approvals: list[ToolApproval] class SubmitRequest(BaseModel): """Request to submit user input.""" session_id: str text: str class TruncateRequest(BaseModel): """Request to truncate conversation history to before a specific user message.""" user_message_index: int class SessionResponse(BaseModel): """Response when creating a new session.""" session_id: str ready: bool = True class PendingApprovalTool(BaseModel): """A tool waiting for user approval.""" tool: str tool_call_id: str arguments: dict[str, Any] = {} class SessionInfo(BaseModel): """Session metadata.""" session_id: str created_at: str is_active: bool is_processing: bool = False message_count: int user_id: str = "dev" pending_approval: list[PendingApprovalTool] | None = None model: str | None = None class HealthResponse(BaseModel): """Health check response.""" status: str = "ok" active_sessions: int = 0 max_sessions: int = 0 class LLMHealthResponse(BaseModel): """LLM provider health check response.""" status: str # "ok" | "error" model: str error: str | None = None error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown" ================================================ FILE: backend/routes/__init__.py ================================================ # Routes package ================================================ FILE: backend/routes/agent.py ================================================ """Agent API routes — REST + SSE endpoints. All routes (except /health) require authentication via the get_current_user dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically. """ import asyncio import json import logging import os from typing import Any from dependencies import get_current_user, require_huggingface_org_member from fastapi import ( APIRouter, Depends, HTTPException, Request, ) from fastapi.responses import StreamingResponse from litellm import acompletion from models import ( ApprovalRequest, HealthResponse, LLMHealthResponse, SessionInfo, SessionResponse, SubmitRequest, TruncateRequest, ) from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, session_manager import user_quotas from agent.core.llm_params import _resolve_llm_params logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["agent"]) AVAILABLE_MODELS = [ { "id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6", "provider": "huggingface", "tier": "free", "recommended": True, }, { "id": "bedrock/us.anthropic.claude-opus-4-6-v1", "label": "Claude Opus 4.6", "provider": "anthropic", "tier": "pro", "recommended": True, }, { "id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7", "provider": "huggingface", "tier": "free", }, { "id": "zai-org/GLM-5.1", "label": "GLM 5.1", "provider": "huggingface", "tier": "free", }, ] def _is_anthropic_model(model_id: str) -> bool: return "anthropic" in model_id async def _require_hf_for_anthropic(request: Request, model_id: str) -> None: """403 if a non-``huggingface``-org user tries to select an Anthropic model. Anthropic models are billed to the Space's ``ANTHROPIC_API_KEY``; every other model in ``AVAILABLE_MODELS`` is routed through HF Router and billed via ``X-HF-Bill-To``. 
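Rejected callers get a 403 whose ``detail.error`` is ``anthropic_restricted``.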
The gate only fires for Anthropic so non-HF users can still freely switch between the free models. Pattern: https://github.com/huggingface/ml-intern/pull/63 """ if not _is_anthropic_model(model_id): return if not await require_huggingface_org_member(request): raise HTTPException( status_code=403, detail={ "error": "anthropic_restricted", "message": ( "Opus is gated to HF staff. Pick a free model — " "Kimi K2.6, MiniMax M2.7, or GLM 5.1 — instead." ), }, ) async def _enforce_claude_quota( user: dict[str, Any], agent_session: AgentSession, ) -> None: """Charge the user's daily Claude quota on first use of Anthropic in a session. Runs at *message-submit* time, not session-create time — so spinning up a Claude session to look around doesn't burn quota. The ``claude_counted`` flag on ``AgentSession`` guards against re-counting the same session. No-ops when the session's current model isn't Anthropic, or when this session has already been charged. Raises 429 when the user has hit their daily cap. """ if agent_session.claude_counted: return model_name = agent_session.session.config.model_name if not _is_anthropic_model(model_name): return user_id = user["user_id"] used = await user_quotas.get_claude_used_today(user_id) cap = user_quotas.daily_cap_for(user.get("plan")) if used >= cap: raise HTTPException( status_code=429, detail={ "error": "claude_daily_cap", "plan": user.get("plan", "free"), "cap": cap, "message": ( "Daily Claude limit reached. Upgrade to HF Pro for " f"{user_quotas.CLAUDE_PRO_DAILY}/day or use a free model." ), }, ) await user_quotas.increment_claude(user_id) agent_session.claude_counted = True def _check_session_access(session_id: str, user: dict[str, Any]) -> None: """Verify the user has access to the given session. Raises 403 or 404.""" info = session_manager.get_session_info(session_id) if not info: raise HTTPException(status_code=404, detail="Session not found") if not session_manager.verify_session_access(session_id, user["user_id"]): raise HTTPException(status_code=403, detail="Access denied to this session") @router.get("/health", response_model=HealthResponse) async def health_check() -> HealthResponse: """Health check endpoint.""" return HealthResponse( status="ok", active_sessions=session_manager.active_session_count, max_sessions=MAX_SESSIONS, ) @router.get("/health/llm", response_model=LLMHealthResponse) async def llm_health_check() -> LLMHealthResponse: """Check if the LLM provider is reachable and the API key is valid. Makes a minimal 1-token completion call. 
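On success the endpoint returns ``{"status": "ok", "model": ...}``; on failure the exception text is classified by substring matching into an ``error_type``.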
Catches common errors: - 401 → invalid API key - 402/insufficient_quota → out of credits - 429 → rate limited - timeout / network → provider unreachable """ model = session_manager.config.model_name try: llm_params = _resolve_llm_params(model, reasoning_effort="high") await acompletion( messages=[{"role": "user", "content": "hi"}], max_tokens=1, timeout=10, **llm_params, ) return LLMHealthResponse(status="ok", model=model) except Exception as e: err_str = str(e).lower() error_type = "unknown" if ( "401" in err_str or "auth" in err_str or "invalid" in err_str or "api key" in err_str ): error_type = "auth" elif ( "402" in err_str or "credit" in err_str or "quota" in err_str or "insufficient" in err_str or "billing" in err_str ): error_type = "credits" elif "429" in err_str or "rate" in err_str: error_type = "rate_limit" elif "timeout" in err_str or "connect" in err_str or "network" in err_str: error_type = "network" logger.warning(f"LLM health check failed ({error_type}): {e}") return LLMHealthResponse( status="error", model=model, error=str(e)[:500], error_type=error_type, ) @router.get("/config/model") async def get_model() -> dict: """Get current model and available models. No auth required.""" return { "current": session_manager.config.model_name, "available": AVAILABLE_MODELS, } _TITLE_STRIP_CHARS = str.maketrans("", "", "`*_~#[]()") @router.post("/title") async def generate_title( request: SubmitRequest, user: dict = Depends(get_current_user) ) -> dict: """Generate a short title for a chat session based on the first user message. Always uses gpt-oss-120b via Cerebras on the HF router. The tab headline renders as plain text, so the model is told to avoid markdown and any stray formatting characters are stripped before returning. gpt-oss is a reasoning model — reasoning_effort=low keeps the reasoning budget small so the 60-token output budget isn't consumed before the title is written. """ api_key = ( os.environ.get("INFERENCE_TOKEN") or (user.get("hf_token") if isinstance(user, dict) else None) or os.environ.get("HF_TOKEN") ) try: response = await acompletion( # Double openai/ prefix: LiteLLM strips the first as its provider # prefix, leaving the HF model id on the wire for the router. model="openai/openai/gpt-oss-120b:cerebras", api_base="https://router.huggingface.co/v1", api_key=api_key, messages=[ { "role": "system", "content": ( "Generate a very short title (max 6 words) for a chat conversation " "that starts with the following user message. " "Reply with ONLY the title in plain text. " "Do NOT use markdown, backticks, asterisks, quotes, brackets, or any " "formatting characters. No punctuation at the end." ), }, {"role": "user", "content": request.text[:500]}, ], max_tokens=60, temperature=0.3, timeout=10, reasoning_effort="low", ) title = response.choices[0].message.content.strip().strip('"').strip("'") title = title.translate(_TITLE_STRIP_CHARS).strip() if len(title) > 50: title = title[:50].rstrip() + "…" return {"title": title} except Exception as e: logger.warning(f"Title generation failed: {e}") fallback = request.text.strip() title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback return {"title": title} @router.post("/session", response_model=SessionResponse) async def create_session( request: Request, user: dict = Depends(get_current_user) ) -> SessionResponse: """Create a new agent session bound to the authenticated user. The user's HF access token is extracted from the Authorization header and stored in the session so that tools (e.g. 
hf_jobs) can act on behalf of the user. Optional body ``{"model"?: }`` selects the session's LLM; unknown ids are rejected (400). The Claude-quota gate runs at message-submit time, not here — spinning up an Opus session to look around is free. Returns 503 if the server or user has reached the session limit. """ # Extract the user's HF token (Bearer header, HttpOnly cookie, or env var) hf_token = None auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): hf_token = auth_header[7:] if not hf_token: hf_token = request.cookies.get("hf_access_token") if not hf_token: hf_token = os.environ.get("HF_TOKEN") # Optional model override. Empty body falls back to the config default. model: str | None = None try: body = await request.json() except Exception: body = None if isinstance(body, dict): model = body.get("model") valid_ids = {m["id"] for m in AVAILABLE_MODELS} if model and model not in valid_ids: raise HTTPException(status_code=400, detail=f"Unknown model: {model}") # Opus is gated to HF staff (PR #63). Only fires when the resolved model # is Anthropic; free models pass through. resolved_model = model or session_manager.config.model_name await _require_hf_for_anthropic(request, resolved_model) try: session_id = await session_manager.create_session( user_id=user["user_id"], hf_token=hf_token, model=model ) except SessionCapacityError as e: raise HTTPException(status_code=503, detail=str(e)) return SessionResponse(session_id=session_id, ready=True) @router.post("/session/restore-summary", response_model=SessionResponse) async def restore_session_summary( request: Request, body: dict, user: dict = Depends(get_current_user) ) -> SessionResponse: """Create a new session seeded with a summary of the caller's prior conversation. The client sends its cached messages; we run the standard summarization prompt on them and drop the result into the new session's context as a user-role system note. Optional ``"model"`` in the body overrides the session's LLM. The Claude-quota gate runs at message-submit time, not here. 
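Example (a sketch; values are illustrative, the /api prefix mirrors the frontend's apiFetch calls, and the model path comes from the frontend's model list):

    POST /api/session/restore-summary
    {"messages": [{"role": "user", "content": "..."},
                  {"role": "assistant", "content": "..."}],
     "model": "moonshotai/Kimi-K2.6"}

    200 -> {"session_id": "<uuid>", "ready": true}
    400 -> missing 'messages' / unknown model
    500 -> summarization failed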
""" messages = body.get("messages") if not isinstance(messages, list) or not messages: raise HTTPException(status_code=400, detail="Missing 'messages' array") hf_token = None auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): hf_token = auth_header[7:] if not hf_token: hf_token = request.cookies.get("hf_access_token") if not hf_token: hf_token = os.environ.get("HF_TOKEN") model = body.get("model") valid_ids = {m["id"] for m in AVAILABLE_MODELS} if model and model not in valid_ids: raise HTTPException(status_code=400, detail=f"Unknown model: {model}") resolved_model = model or session_manager.config.model_name await _require_hf_for_anthropic(request, resolved_model) try: session_id = await session_manager.create_session( user_id=user["user_id"], hf_token=hf_token, model=model ) except SessionCapacityError as e: raise HTTPException(status_code=503, detail=str(e)) try: summarized = await session_manager.seed_from_summary(session_id, messages) except ValueError as e: raise HTTPException(status_code=500, detail=str(e)) except Exception as e: logger.exception("seed_from_summary failed") raise HTTPException(status_code=500, detail=f"Summary failed: {e}") logger.info( f"Seeded session {session_id} for {user.get('username', 'unknown')} " f"(summary of {summarized} messages)" ) return SessionResponse(session_id=session_id, ready=True) @router.get("/session/{session_id}", response_model=SessionInfo) async def get_session( session_id: str, user: dict = Depends(get_current_user) ) -> SessionInfo: """Get session information. Only accessible by the session owner.""" _check_session_access(session_id, user) info = session_manager.get_session_info(session_id) return SessionInfo(**info) @router.post("/session/{session_id}/model") async def set_session_model( session_id: str, body: dict, request: Request, user: dict = Depends(get_current_user), ) -> dict: """Switch the active model for a single session (tab-scoped). Takes effect on the next LLM call in that session — other sessions (including other browser tabs) are unaffected. Model switches don't charge quota — the Claude-quota gate only fires at message-submit time. Switching TO an Anthropic model requires HF org membership (PR #63); free-model switches are unrestricted. 
""" _check_session_access(session_id, user) model_id = body.get("model") if not model_id: raise HTTPException(status_code=400, detail="Missing 'model' field") valid_ids = {m["id"] for m in AVAILABLE_MODELS} if model_id not in valid_ids: raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") await _require_hf_for_anthropic(request, model_id) agent_session = session_manager.sessions.get(session_id) if not agent_session: raise HTTPException(status_code=404, detail="Session not found") agent_session.session.update_model(model_id) logger.info( f"Session {session_id} model → {model_id} " f"(by {user.get('username', 'unknown')})" ) return {"session_id": session_id, "model": model_id} @router.get("/user/quota") async def get_user_quota(user: dict = Depends(get_current_user)) -> dict: """Return the user's plan tier and today's Claude-session quota state.""" plan = user.get("plan", "free") used = await user_quotas.get_claude_used_today(user["user_id"]) cap = user_quotas.daily_cap_for(plan) return { "plan": plan, "claude_used_today": used, "claude_daily_cap": cap, "claude_remaining": max(0, cap - used), } @router.get("/sessions", response_model=list[SessionInfo]) async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]: """List sessions belonging to the authenticated user.""" sessions = session_manager.list_sessions(user_id=user["user_id"]) return [SessionInfo(**s) for s in sessions] @router.delete("/session/{session_id}") async def delete_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Delete a session. Only accessible by the session owner.""" _check_session_access(session_id, user) success = await session_manager.delete_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found") return {"status": "deleted", "session_id": session_id} @router.post("/submit") async def submit_input( request: SubmitRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit user input to a session. Only accessible by the session owner.""" _check_session_access(request.session_id, user) agent_session = session_manager.sessions.get(request.session_id) if agent_session is not None: await _enforce_claude_quota(user, agent_session) success = await session_manager.submit_user_input(request.session_id, request.text) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "submitted", "session_id": request.session_id} @router.post("/approve") async def submit_approval( request: ApprovalRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit tool approvals to a session. 
Only accessible by the session owner.""" _check_session_access(request.session_id, user) approvals = [ { "tool_call_id": a.tool_call_id, "approved": a.approved, "feedback": a.feedback, "edited_script": a.edited_script, } for a in request.approvals ] success = await session_manager.submit_approval(request.session_id, approvals) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "submitted", "session_id": request.session_id} @router.post("/chat/{session_id}") async def chat_sse( session_id: str, request: Request, user: dict = Depends(get_current_user), ) -> StreamingResponse: """SSE endpoint: submit input or approval, then stream events until turn ends.""" _check_session_access(session_id, user) agent_session = session_manager.sessions.get(session_id) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") # Parse body body = await request.json() # Subscribe BEFORE submitting so we never miss events — even if the # agent loop processes the submission before this coroutine continues. broadcaster = agent_session.broadcaster sub_id, event_queue = broadcaster.subscribe() # Submit the operation text = body.get("text") approvals = body.get("approvals") # Gate user-message sends against the daily Claude quota. Approvals are # continuations of an in-progress turn — the session was already charged # on its first message, so we skip the gate there. if text is not None and not approvals: try: await _enforce_claude_quota(user, agent_session) except HTTPException: broadcaster.unsubscribe(sub_id) raise try: if approvals: formatted = [ { "tool_call_id": a["tool_call_id"], "approved": a["approved"], "feedback": a.get("feedback"), "edited_script": a.get("edited_script"), } for a in approvals ] success = await session_manager.submit_approval(session_id, formatted) elif text is not None: success = await session_manager.submit_user_input(session_id, text) else: broadcaster.unsubscribe(sub_id) raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") if not success: broadcaster.unsubscribe(sub_id) raise HTTPException(status_code=404, detail="Session not found or inactive") except HTTPException: raise except Exception: broadcaster.unsubscribe(sub_id) raise return _sse_response(broadcaster, event_queue, sub_id) # --------------------------------------------------------------------------- # Shared SSE helpers # --------------------------------------------------------------------------- _TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"} _SSE_KEEPALIVE_SECONDS = 15 def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse: """Build a StreamingResponse that drains *event_queue* as SSE, sending keepalive comments every 15 s to prevent proxy timeouts.""" async def event_generator(): try: while True: try: msg = await asyncio.wait_for( event_queue.get(), timeout=_SSE_KEEPALIVE_SECONDS ) except asyncio.TimeoutError: # SSE comment — ignored by parsers, keeps connection alive yield ": keepalive\n\n" continue event_type = msg.get("event_type", "") yield f"data: {json.dumps(msg)}\n\n" if event_type in _TERMINAL_EVENTS: break finally: broadcaster.unsubscribe(sub_id) return StreamingResponse( event_generator(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) @router.get("/events/{session_id}") async def subscribe_events( session_id: 
str, user: dict = Depends(get_current_user), ) -> StreamingResponse: """Subscribe to events for a running session without submitting new input. Used by the frontend to re-attach after a connection drop (e.g. screen sleep). Returns 404 if the session isn't active. """ _check_session_access(session_id, user) agent_session = session_manager.sessions.get(session_id) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") broadcaster = agent_session.broadcaster sub_id, event_queue = broadcaster.subscribe() return _sse_response(broadcaster, event_queue, sub_id) @router.post("/interrupt/{session_id}") async def interrupt_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Interrupt the current operation in a session.""" _check_session_access(session_id, user) success = await session_manager.interrupt(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "interrupted", "session_id": session_id} @router.get("/session/{session_id}/messages") async def get_session_messages( session_id: str, user: dict = Depends(get_current_user) ) -> list[dict]: """Return the session's message history from memory.""" _check_session_access(session_id, user) agent_session = session_manager.sessions.get(session_id) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") return [msg.model_dump() for msg in agent_session.session.context_manager.items] @router.post("/undo/{session_id}") async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict: """Undo the last turn in a session.""" _check_session_access(session_id, user) success = await session_manager.undo(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "undo_requested", "session_id": session_id} @router.post("/truncate/{session_id}") async def truncate_session( session_id: str, body: TruncateRequest, user: dict = Depends(get_current_user) ) -> dict: """Truncate conversation to before a specific user message.""" _check_session_access(session_id, user) success = await session_manager.truncate(session_id, body.user_message_index) if not success: raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range") return {"status": "truncated", "session_id": session_id} @router.post("/compact/{session_id}") async def compact_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Compact the context in a session.""" _check_session_access(session_id, user) success = await session_manager.compact(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "compact_requested", "session_id": session_id} @router.post("/shutdown/{session_id}") async def shutdown_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Shut down a session.""" _check_session_access(session_id, user) success = await session_manager.shutdown_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} ================================================ FILE: backend/routes/auth.py ================================================ """Authentication routes for HF OAuth.
Handles the OAuth 2.0 authorization code flow with HF as provider. After successful auth, sets an HttpOnly cookie with the access token. """ import os import secrets import time from urllib.parse import urlencode import httpx from dependencies import AUTH_ENABLED, check_org_membership, get_current_user from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import RedirectResponse router = APIRouter(prefix="/auth", tags=["auth"]) # OAuth configuration from environment OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "") OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "") OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") # In-memory OAuth state store with expiry (5 min TTL) _OAUTH_STATE_TTL = 300 oauth_states: dict[str, dict] = {} def _cleanup_expired_states() -> None: """Remove expired OAuth states to prevent memory growth.""" now = time.time() expired = [k for k, v in oauth_states.items() if now > v.get("expires_at", 0)] for k in expired: del oauth_states[k] def get_redirect_uri(request: Request) -> str: """Get the OAuth callback redirect URI.""" # In HF Spaces, use the SPACE_HOST if available space_host = os.environ.get("SPACE_HOST") if space_host: return f"https://{space_host}/auth/callback" # Otherwise construct from request return str(request.url_for("oauth_callback")) @router.get("/login") async def oauth_login(request: Request) -> RedirectResponse: """Initiate OAuth login flow.""" if not OAUTH_CLIENT_ID: raise HTTPException( status_code=500, detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.", ) # Clean up expired states to prevent memory growth _cleanup_expired_states() # Generate state for CSRF protection state = secrets.token_urlsafe(32) oauth_states[state] = { "redirect_uri": get_redirect_uri(request), "expires_at": time.time() + _OAUTH_STATE_TTL, } # Build authorization URL params = { "client_id": OAUTH_CLIENT_ID, "redirect_uri": get_redirect_uri(request), "scope": "openid profile read-repos write-repos contribute-repos manage-repos inference-api jobs write-discussions", "response_type": "code", "state": state, "orgIds": os.environ.get( "HF_OAUTH_ORG_ID", "698dbf55845d85df163175f1" ), # ml-agent-explorers } auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}" return RedirectResponse(url=auth_url) @router.get("/callback") async def oauth_callback( request: Request, code: str = "", state: str = "" ) -> RedirectResponse: """Handle OAuth callback.""" # Verify state if state not in oauth_states: raise HTTPException(status_code=400, detail="Invalid state parameter") stored_state = oauth_states.pop(state) redirect_uri = stored_state["redirect_uri"] if not code: raise HTTPException(status_code=400, detail="No authorization code provided") # Exchange code for token token_url = f"{OPENID_PROVIDER_URL}/oauth/token" async with httpx.AsyncClient() as client: try: response = await client.post( token_url, data={ "grant_type": "authorization_code", "code": code, "redirect_uri": redirect_uri, "client_id": OAUTH_CLIENT_ID, "client_secret": OAUTH_CLIENT_SECRET, }, ) response.raise_for_status() token_data = response.json() except httpx.HTTPError as e: raise HTTPException(status_code=500, detail=f"Token exchange failed: {e}") # Get user info access_token = token_data.get("access_token") if not access_token: raise HTTPException( status_code=500, detail="Token exchange succeeded but no access_token was returned.", ) # Fetch user info (optional — failure is not fatal) async with 
httpx.AsyncClient() as client: try: userinfo_response = await client.get( f"{OPENID_PROVIDER_URL}/oauth/userinfo", headers={"Authorization": f"Bearer {access_token}"}, ) userinfo_response.raise_for_status() except httpx.HTTPError: pass # user_info not required for auth flow # Set access token as HttpOnly cookie (not in URL — avoids leaks via # Referrer headers, browser history, and server logs) is_production = bool(os.environ.get("SPACE_HOST")) response = RedirectResponse(url="/", status_code=302) response.set_cookie( key="hf_access_token", value=access_token, httponly=True, secure=is_production, # Secure flag only in production (HTTPS) samesite="lax", max_age=3600 * 24 * 7, # 7 days path="/", ) return response @router.get("/logout") async def logout() -> RedirectResponse: """Log out the user by clearing the auth cookie.""" response = RedirectResponse(url="/") response.delete_cookie(key="hf_access_token", path="/") return response @router.get("/status") async def auth_status() -> dict: """Check if OAuth is enabled on this instance.""" return {"auth_enabled": AUTH_ENABLED} @router.get("/me") async def get_me(user: dict = Depends(get_current_user)) -> dict: """Get current user info. Returns the authenticated user or dev user. Uses the shared auth dependency which handles cookie + Bearer token. """ return user ORG_NAME = "ml-agent-explorers" @router.get("/org-membership") async def org_membership( request: Request, user: dict = Depends(get_current_user) ) -> dict: """Check if the authenticated user belongs to the ml-agent-explorers org.""" if not AUTH_ENABLED: return {"is_member": True} token = request.cookies.get("hf_access_token") or "" if not token: return {"is_member": False} is_member = await check_org_membership(token, ORG_NAME) return {"is_member": is_member} ================================================ FILE: backend/session_manager.py ================================================ """Session manager for handling multiple concurrent agent sessions.""" import asyncio import logging import uuid from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import Any, Optional from agent.config import load_config from agent.core.agent_loop import process_submission from agent.core.session import Event, OpType, Session from agent.core.tools import ToolRouter # Get project root (parent of backend directory) PROJECT_ROOT = Path(__file__).parent.parent DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "main_agent_config.json") # These dataclasses match agent/main.py structure @dataclass class Operation: """Operation to be executed by the agent.""" op_type: OpType data: Optional[dict[str, Any]] = None @dataclass class Submission: """Submission to the agent loop.""" id: str operation: Operation logger = logging.getLogger(__name__) class EventBroadcaster: """Reads from the agent's event queue and fans out to SSE subscribers. Events that arrive when no subscribers are listening are discarded. With SSE each turn is a separate request, so there is no reconnect scenario that would need buffered replay. """ def __init__(self, event_queue: asyncio.Queue): self._source = event_queue self._subscribers: dict[int, asyncio.Queue] = {} self._counter = 0 def subscribe(self) -> tuple[int, asyncio.Queue]: """Create a new subscriber. 
Returns (id, queue).""" self._counter += 1 sub_id = self._counter q: asyncio.Queue = asyncio.Queue() self._subscribers[sub_id] = q return sub_id, q def unsubscribe(self, sub_id: int) -> None: self._subscribers.pop(sub_id, None) async def run(self) -> None: """Main loop — reads from source queue and broadcasts.""" while True: try: event: Event = await self._source.get() msg = {"event_type": event.event_type, "data": event.data} for q in self._subscribers.values(): await q.put(msg) except asyncio.CancelledError: break except Exception as e: logger.error(f"EventBroadcaster error: {e}") @dataclass class AgentSession: """Wrapper for an agent session with its associated resources.""" session_id: str session: Session tool_router: ToolRouter submission_queue: asyncio.Queue user_id: str = "dev" # Owner of this session hf_token: str | None = None # User's HF OAuth token for tool execution task: asyncio.Task | None = None created_at: datetime = field(default_factory=datetime.utcnow) is_active: bool = True is_processing: bool = False # True while a submission is being executed broadcaster: Any = None # True once this session has been counted against the user's daily # Claude quota. Guards double-counting when the user re-selects an # Anthropic model mid-session. claude_counted: bool = False class SessionCapacityError(Exception): """Raised when no more sessions can be created.""" def __init__(self, message: str, error_type: str = "global") -> None: super().__init__(message) self.error_type = error_type # "global" or "per_user" # ── Capacity limits ───────────────────────────────────────────────── # Sized for HF Spaces 8 vCPU / 32 GB RAM. # Each session uses ~10-20 MB (context, tools, queues, task); 200 × 20 MB # = 4 GB worst case, leaving plenty of headroom for the Python runtime # and per-request overhead. MAX_SESSIONS: int = 200 MAX_SESSIONS_PER_USER: int = 10 class SessionManager: """Manages multiple concurrent agent sessions.""" def __init__(self, config_path: str | None = None) -> None: self.config = load_config(config_path or DEFAULT_CONFIG_PATH) self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" return sum( 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active ) async def create_session( self, user_id: str = "dev", hf_token: str | None = None, model: str | None = None, ) -> str: """Create a new agent session and return its ID. Session() and ToolRouter() constructors contain blocking I/O (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are executed in a thread pool to avoid freezing the async event loop. Args: user_id: The ID of the user who owns this session. hf_token: The user's HF OAuth token, stored for tool execution. model: Optional model override. When set, replaces ``model_name`` on the per-session config clone. None falls back to the config default. Raises: SessionCapacityError: If the server or user has reached the maximum number of concurrent sessions. """ # ── Capacity checks ────────────────────────────────────────── async with self._lock: active_count = self.active_session_count if active_count >= MAX_SESSIONS: raise SessionCapacityError( f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). 
" "Please try again later.", error_type="global", ) if user_id != "dev": user_count = self._count_user_sessions(user_id) if user_count >= MAX_SESSIONS_PER_USER: raise SessionCapacityError( f"You have reached the maximum of {MAX_SESSIONS_PER_USER} " "concurrent sessions. Please close an existing session first.", error_type="per_user", ) session_id = str(uuid.uuid4()) # Create queues for this session submission_queue: asyncio.Queue = asyncio.Queue() event_queue: asyncio.Queue = asyncio.Queue() # Run blocking constructors in a thread to keep the event loop responsive. # Without this, Session.__init__ → ContextManager → litellm.get_max_tokens() # blocks all HTTP/SSE handling. import time as _time def _create_session_sync(): t0 = _time.monotonic() tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) # Deep-copy config so each session's model switches independently — # tab A picking GLM doesn't flip tab B off Claude. session_config = self.config.model_copy(deep=True) if model: session_config.model_name = model session = Session( event_queue, config=session_config, tool_router=tool_router, hf_token=hf_token, ) t1 = _time.monotonic() logger.info(f"Session initialized in {t1 - t0:.2f}s") return tool_router, session tool_router, session = await asyncio.to_thread(_create_session_sync) # Create wrapper agent_session = AgentSession( session_id=session_id, session=session, tool_router=tool_router, submission_queue=submission_queue, user_id=user_id, hf_token=hf_token, ) async with self._lock: self.sessions[session_id] = agent_session # Start the agent loop task task = asyncio.create_task( self._run_session(session_id, submission_queue, event_queue, tool_router) ) agent_session.task = task logger.info(f"Created session {session_id} for user {user_id}") return session_id async def seed_from_summary(self, session_id: str, messages: list[dict]) -> int: """Rehydrate a session from cached prior messages via summarization. Runs the standard summarization prompt (same one compaction uses) over the provided messages, then seeds the new session's context with that summary. Tool-call pairing concerns disappear because the output is plain text. Returns the number of messages summarized. """ from litellm import Message from agent.context_manager.manager import _RESTORE_PROMPT, summarize_messages agent_session = self.sessions.get(session_id) if not agent_session: raise ValueError(f"Session {session_id} not found") # Parse into Message objects, tolerating malformed entries. parsed: list[Message] = [] for raw in messages: if raw.get("role") == "system": continue # the new session has its own system prompt try: parsed.append(Message.model_validate(raw)) except Exception as e: logger.warning("Dropping malformed message during seed: %s", e) if not parsed: return 0 session = agent_session.session # Pass the real tool specs so the summarizer sees what the agent # actually has — otherwise Anthropic's modify_params injects a # dummy tool and the summarizer editorializes that the original # tool calls were fabricated. 
tool_specs = None try: tool_specs = agent_session.tool_router.get_tool_specs_for_llm() except Exception: pass try: summary, _ = await summarize_messages( parsed, model_name=session.config.model_name, hf_token=session.hf_token, max_tokens=4000, prompt=_RESTORE_PROMPT, tool_specs=tool_specs, ) except Exception as e: logger.error("Summary call failed during seed: %s", e) raise seed = Message( role="user", content=( "[SYSTEM: Your prior memory of this conversation — written " "in your own voice right before restart. Continue from here.]\n\n" + (summary or "(no summary returned)") ), ) session.context_manager.items.append(seed) return len(parsed) @staticmethod async def _cleanup_sandbox(session: Session) -> None: """Delete the sandbox Space if one was created for this session.""" sandbox = getattr(session, "sandbox", None) if sandbox and getattr(sandbox, "_owns_space", False): try: logger.info(f"Deleting sandbox {sandbox.space_id}...") await asyncio.to_thread(sandbox.delete) except Exception as e: logger.warning(f"Failed to delete sandbox {sandbox.space_id}: {e}") async def _run_session( self, session_id: str, submission_queue: asyncio.Queue, event_queue: asyncio.Queue, tool_router: ToolRouter, ) -> None: """Run the agent loop for a session and broadcast events via EventBroadcaster.""" agent_session = self.sessions.get(session_id) if not agent_session: logger.error(f"Session {session_id} not found") return session = agent_session.session # Start event broadcaster task broadcaster = EventBroadcaster(event_queue) agent_session.broadcaster = broadcaster broadcast_task = asyncio.create_task(broadcaster.run()) try: async with tool_router: # Send ready event await session.send_event( Event(event_type="ready", data={"message": "Agent initialized"}) ) while session.is_running: try: # Wait for submission with timeout to allow checking is_running submission = await asyncio.wait_for( submission_queue.get(), timeout=1.0 ) agent_session.is_processing = True try: should_continue = await process_submission(session, submission) finally: agent_session.is_processing = False if not should_continue: break except asyncio.TimeoutError: continue except asyncio.CancelledError: logger.info(f"Session {session_id} cancelled") break except Exception as e: logger.error(f"Error in session {session_id}: {e}") await session.send_event( Event(event_type="error", data={"error": str(e)}) ) finally: broadcast_task.cancel() try: await broadcast_task except asyncio.CancelledError: pass await self._cleanup_sandbox(session) async with self._lock: if session_id in self.sessions: self.sessions[session_id].is_active = False logger.info(f"Session {session_id} ended") async def submit(self, session_id: str, operation: Operation) -> bool: """Submit an operation to a session.""" async with self._lock: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: logger.warning(f"Session {session_id} not found or inactive") return False submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation) await agent_session.submission_queue.put(submission) return True async def submit_user_input(self, session_id: str, text: str) -> bool: """Submit user input to a session.""" operation = Operation(op_type=OpType.USER_INPUT, data={"text": text}) return await self.submit(session_id, operation) async def submit_approval( self, session_id: str, approvals: list[dict[str, Any]] ) -> bool: """Submit tool approvals to a session.""" operation = Operation( op_type=OpType.EXEC_APPROVAL, data={"approvals": 
approvals} ) return await self.submit(session_id, operation) async def interrupt(self, session_id: str) -> bool: """Interrupt a session by signalling cancellation directly (bypasses queue).""" agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False agent_session.session.cancel() return True async def undo(self, session_id: str) -> bool: """Undo last turn in a session.""" operation = Operation(op_type=OpType.UNDO) return await self.submit(session_id, operation) async def truncate(self, session_id: str, user_message_index: int) -> bool: """Truncate conversation to before a specific user message (direct, no queue).""" async with self._lock: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False return agent_session.session.context_manager.truncate_to_user_message(user_message_index) async def compact(self, session_id: str) -> bool: """Compact context in a session.""" operation = Operation(op_type=OpType.COMPACT) return await self.submit(session_id, operation) async def shutdown_session(self, session_id: str) -> bool: """Shutdown a specific session.""" operation = Operation(op_type=OpType.SHUTDOWN) success = await self.submit(session_id, operation) if success: async with self._lock: agent_session = self.sessions.get(session_id) if agent_session and agent_session.task: # Wait for task to complete try: await asyncio.wait_for(agent_session.task, timeout=5.0) except asyncio.TimeoutError: agent_session.task.cancel() return success async def delete_session(self, session_id: str) -> bool: """Delete a session entirely.""" async with self._lock: agent_session = self.sessions.pop(session_id, None) if not agent_session: return False # Clean up sandbox Space before cancelling the task await self._cleanup_sandbox(agent_session.session) # Cancel the task if running if agent_session.task and not agent_session.task.done(): agent_session.task.cancel() try: await agent_session.task except asyncio.CancelledError: pass return True def get_session_owner(self, session_id: str) -> str | None: """Get the user_id that owns a session, or None if session doesn't exist.""" agent_session = self.sessions.get(session_id) if not agent_session: return None return agent_session.user_id def verify_session_access(self, session_id: str, user_id: str) -> bool: """Check if a user has access to a session. 
Returns True if: - The session exists AND the user owns it - The user_id is "dev" (dev mode bypass) """ owner = self.get_session_owner(session_id) if owner is None: return False if user_id == "dev" or owner == "dev": return True return owner == user_id def get_session_info(self, session_id: str) -> dict[str, Any] | None: """Get information about a session.""" agent_session = self.sessions.get(session_id) if not agent_session: return None # Extract pending approval tools if any pending_approval = None pa = agent_session.session.pending_approval if pa and pa.get("tool_calls"): import json pending_approval = [] for tc in pa["tool_calls"]: try: args = json.loads(tc.function.arguments) except (json.JSONDecodeError, AttributeError): args = {} pending_approval.append({ "tool": tc.function.name, "tool_call_id": tc.id, "arguments": args, }) return { "session_id": session_id, "created_at": agent_session.created_at.isoformat(), "is_active": agent_session.is_active, "is_processing": agent_session.is_processing, "message_count": len(agent_session.session.context_manager.items), "user_id": agent_session.user_id, "pending_approval": pending_approval, "model": agent_session.session.config.model_name, } def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: """List sessions, optionally filtered by user. Args: user_id: If provided, only return sessions owned by this user. If "dev", return all sessions (dev mode). """ results = [] for sid in self.sessions: info = self.get_session_info(sid) if not info: continue if user_id and user_id != "dev" and info.get("user_id") != user_id: continue results.append(info) return results @property def active_session_count(self) -> int: """Get count of active sessions.""" return sum(1 for s in self.sessions.values() if s.is_active) # Global session manager instance session_manager = SessionManager() ================================================ FILE: backend/start.sh ================================================ #!/bin/bash # Entrypoint for HF Spaces dev mode compatibility. # Dev mode spawns CMD multiple times simultaneously on restart. # Only the first instance can bind port 7860 — the rest must exit # with code 0 so the dev mode daemon doesn't mark the app as crashed. # Run uvicorn; any non-zero exit is treated as a port conflict and exits cleanly. uvicorn main:app --host 0.0.0.0 --port 7860 EXIT_CODE=$? if [ $EXIT_CODE -ne 0 ]; then # No explicit port-in-use check; assume another instance already bound the port. echo "uvicorn exited with code $EXIT_CODE, exiting gracefully." exit 0 fi ================================================ FILE: backend/user_quotas.py ================================================ """In-memory daily quota for Claude sessions. Tracks per-user Claude sessions against a daily cap derived from the user's HF plan. Caps reset at UTC midnight; the store itself is in-process and wipes on restart (deliberate — the cost of occasional over-subsidy at restart is much lower than running a DB). Unit: *sessions*, not messages. A session is charged once, on its first Claude message (the gate runs at message-submit time, not at session creation); re-selecting an Anthropic model in an already-charged session doesn't re-count (`AgentSession.claude_counted` guards that).
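Rollover sketch (illustrative user id; semantics of the helpers defined below):

    await increment_claude("alice")       # day D: count -> 1
    await get_claude_used_today("alice")  # day D: returns 1
    # first read on day D+1 drops the stale entry and returns 0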
Cap tiers: free user → CLAUDE_FREE_DAILY (1) pro / org → CLAUDE_PRO_DAILY (20) """ import asyncio import os from datetime import UTC, datetime CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1")) CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20")) # user_id -> (day_utc_iso, count_for_that_day) _claude_counts: dict[str, tuple[str, int]] = {} _lock = asyncio.Lock() def _today() -> str: return datetime.now(UTC).date().isoformat() def daily_cap_for(plan: str | None) -> int: """Return the daily Claude-session cap for the given plan.""" return CLAUDE_FREE_DAILY if (plan or "free") == "free" else CLAUDE_PRO_DAILY async def get_claude_used_today(user_id: str) -> int: """Return today's Claude session count for the user (0 if none / stale day).""" async with _lock: entry = _claude_counts.get(user_id) if entry is None: return 0 day, count = entry if day != _today(): # Stale day — drop the entry so the first increment starts fresh. _claude_counts.pop(user_id, None) return 0 return count async def increment_claude(user_id: str) -> int: """Bump today's Claude session count for the user. Returns the new value.""" async with _lock: today = _today() day, count = _claude_counts.get(user_id, (today, 0)) if day != today: count = 0 count += 1 _claude_counts[user_id] = (today, count) return count async def refund_claude(user_id: str) -> None: """Decrement today's count — used when session creation fails after a successful gate.""" async with _lock: entry = _claude_counts.get(user_id) if entry is None: return day, count = entry if day != _today(): _claude_counts.pop(user_id, None) return new_count = max(0, count - 1) if new_count == 0: _claude_counts.pop(user_id, None) else: _claude_counts[user_id] = (day, new_count) def _reset_for_tests() -> None: """Test-only: clear the in-memory store.""" _claude_counts.clear() ================================================ FILE: configs/main_agent_config.json ================================================ { "model_name": "bedrock/us.anthropic.claude-opus-4-6-v1", "save_sessions": true, "session_dataset_repo": "akseljoonas/hf-agent-sessions", "yolo_mode": false, "confirm_cpu_jobs": true, "auto_file_upload": true, "mcpServers": { "hf-mcp-server": { "transport": "http", "url": "https://huggingface.co/mcp?login" } } } ================================================ FILE: frontend/eslint.config.js ================================================ import js from '@eslint/js' import globals from 'globals' import reactHooks from 'eslint-plugin-react-hooks' import reactRefresh from 'eslint-plugin-react-refresh' import tseslint from 'typescript-eslint' export default tseslint.config( { ignores: ['dist'] }, { extends: [js.configs.recommended, ...tseslint.configs.recommended], files: ['**/*.{ts,tsx}'], languageOptions: { ecmaVersion: 2020, globals: globals.browser, }, plugins: { 'react-hooks': reactHooks, 'react-refresh': reactRefresh, }, rules: { ...reactHooks.configs.recommended.rules, 'react-refresh/only-export-components': [ 'warn', { allowConstantExport: true }, ], }, }, ) ================================================ FILE: frontend/index.html ================================================ ML Intern
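For reference, a minimal Python consumer for the SSE chat endpoint in backend/routes/agent.py. This is a sketch, not a file from this repo: it assumes the /api prefix used by the frontend's apiFetch calls and the port from backend/start.sh; the session id and token are placeholders. The framing (``data:`` lines, ``: keepalive`` comments, terminal event names) follows _sse_response and _TERMINAL_EVENTS above.

import asyncio
import json

import httpx

# Mirrors _TERMINAL_EVENTS in backend/routes/agent.py.
TERMINAL = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"}

async def chat(session_id: str, text: str, token: str) -> None:
    """Submit one message and print events until the turn ends."""
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            f"http://localhost:7860/api/chat/{session_id}",  # /api prefix assumed
            headers={"Authorization": f"Bearer {token}"},
            json={"text": text},
        ) as resp:
            async for line in resp.aiter_lines():
                # Skip blank frame separators and ": keepalive" comments.
                if not line or line.startswith(":"):
                    continue
                if not line.startswith("data: "):
                    continue
                msg = json.loads(line[len("data: "):])
                print(msg.get("event_type"), msg.get("data"))
                if msg.get("event_type") in TERMINAL:
                    break

# asyncio.run(chat("<session-id>", "hi", "hf_xxx"))  # placeholders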
================================================ FILE: frontend/package.json ================================================ { "name": "hf-agent-frontend", "private": true, "version": "1.0.0", "type": "module", "scripts": { "dev": "vite", "build": "tsc -b && vite build", "lint": "eslint .", "preview": "vite preview" }, "dependencies": { "@ai-sdk/react": "^3.0.93", "@emotion/react": "^11.13.0", "@emotion/styled": "^11.13.0", "@mui/icons-material": "^6.1.0", "@mui/material": "^6.1.0", "ai": "^6.0.91", "react": "^18.3.1", "react-dom": "^18.3.1", "react-markdown": "^9.0.1", "react-syntax-highlighter": "^16.1.0", "remark-gfm": "^4.0.1", "zustand": "^5.0.0" }, "devDependencies": { "@eslint/js": "^9.13.0", "@types/react": "^18.3.12", "@types/react-dom": "^18.3.1", "@types/react-syntax-highlighter": "^15.5.13", "@vitejs/plugin-react": "^4.3.3", "eslint": "^9.13.0", "eslint-plugin-react-hooks": "^5.0.0", "eslint-plugin-react-refresh": "^0.4.13", "globals": "^15.11.0", "typescript": "~5.6.2", "typescript-eslint": "^8.10.0", "vite": "^5.4.10" } } ================================================ FILE: frontend/src/App.tsx ================================================ import { Box } from '@mui/material'; import AppLayout from '@/components/Layout/AppLayout'; import { useAuth } from '@/hooks/useAuth'; function App() { // Non-blocking auth check — fires in background, updates store when done. // If auth fails later, apiFetch redirects to /auth/login. useAuth(); return ( ); } export default App; ================================================ FILE: frontend/src/components/Chat/ActivityStatusBar.tsx ================================================ import { Box, Typography } from '@mui/material'; import { keyframes } from '@mui/system'; import { useAgentStore, type ActivityStatus } from '@/store/agentStore'; const shimmer = keyframes` 0% { background-position: -100% center; } 50% { background-position: 200% center; } 100% { background-position: -100% center; } `; const TOOL_LABELS: Record = { sandbox_create: 'Creating sandbox for code development, this might take 1-2 minutes', bash: 'Running command in sandbox', hf_jobs: 'Running a GPU job, this might take a while', hf_repo_files: 'Uploading file', hf_repo_git: 'Git operation', hf_inspect_dataset: 'Inspecting dataset', hf_search: 'Searching', plan_tool: 'Planning', research: 'Researching', }; /** Format raw research log into a clean status label. */ function formatResearchStatus(raw: string): string { const s = raw.replace(/^▸\s*/, ''); const jsonStart = s.indexOf('{'); const toolName = jsonStart > 0 ? s.slice(0, jsonStart).trim() : s.trim(); let args: Record = {}; if (jsonStart > 0) { const jsonStr = s.slice(jsonStart); try { const parsed = JSON.parse(jsonStr); for (const [k, v] of Object.entries(parsed)) { if (typeof v === 'string') args[k] = v; } } catch { // JSON is likely truncated — extract complete "key": "value" pairs for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) { args[m[1]] = m[2]; } // Also try to extract a truncated value for known keys if not found yet if (!args.query && !args.arxiv_id) { const partial = jsonStr.match(/"(query|arxiv_id)":\s*"([^"]*)/); if (partial) args[partial[1]] = partial[2]; } } } if (toolName === 'github_find_examples') { const d = (args.keyword) || (args.repo); return d ? `Finding examples: ${d}` : 'Finding examples'; } if (toolName === 'github_read_file') { const f = ((args.path) || '').split('/').pop(); return f ? 
`Reading ${f}` : 'Reading file'; } if (toolName === 'explore_hf_docs') { const d = (args.endpoint) || (args.query); return d ? `Exploring docs: ${d}` : 'Exploring docs'; } if (toolName === 'fetch_hf_docs') { const p = ((args.url) || '').split('/').pop()?.replace(/\.md$/, ''); return p ? `Reading docs: ${p}` : 'Fetching docs'; } if (toolName === 'hf_inspect_dataset') { const d = args.dataset as string; return d ? `Inspecting dataset: ${d}` : 'Inspecting dataset'; } if (toolName === 'hf_papers') { const op = args.operation as string; const detail = (args.query) || (args.arxiv_id) || (args.positive_ids); const opLabels: Record = { trending: 'Browsing trending papers', search: 'Searching papers', paper_details: 'Reading paper details', read_paper: 'Reading paper', citation_graph: 'Tracing citations', snippet_search: 'Searching paper passages', recommend: 'Finding similar papers', find_datasets: 'Finding paper datasets', find_models: 'Finding paper models', find_collections: 'Finding paper collections', find_all_resources: 'Finding paper resources', }; const base = (op && opLabels[op]) || 'Searching papers'; return detail ? `${base}: ${detail}` : base; } if (toolName === 'find_hf_api') { const d = (args.query) || (args.tag); return d ? `Finding API: ${d}` : 'Finding API endpoints'; } if (toolName === 'hf_repo_files') { const d = (args.repo_id) || (args.repo); return d ? `Reading ${d} files` : 'Reading repo files'; } return 'Researching'; } function statusLabel(status: ActivityStatus): string { switch (status.type) { case 'thinking': return 'Thinking'; case 'streaming': return 'Writing'; case 'tool': { if (status.toolName === 'research' && status.description) { return formatResearchStatus(status.description); } const base = status.description || TOOL_LABELS[status.toolName] || `Running ${status.toolName}`; if (status.toolName === 'bash' && status.description && /install/i.test(status.description)) { return `${base} — this can take a few minutes, sit tight`; } return base; } case 'waiting-approval': return 'Waiting for approval'; case 'cancelled': return 'What should the agent do instead?'; default: return ''; } } export default function ActivityStatusBar() { const activityStatus = useAgentStore(s => s.activityStatus); if (activityStatus.type === 'idle') return null; const label = statusLabel(activityStatus); return ( {label}{activityStatus.type !== 'cancelled' && '…'} ); } ================================================ FILE: frontend/src/components/Chat/AssistantMessage.tsx ================================================ import { useMemo } from 'react'; import { Box, Stack, Typography } from '@mui/material'; import MarkdownContent from './MarkdownContent'; import ToolCallGroup from './ToolCallGroup'; import type { UIMessage } from 'ai'; import type { MessageMeta } from '@/types/agent'; interface AssistantMessageProps { message: UIMessage; isStreaming?: boolean; approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; } /** * Groups consecutive tool parts together so they render as a single * ToolCallGroup (visually identical to the old segments approach). 
*/ type DynamicToolPart = Extract; function groupParts(parts: UIMessage['parts']) { const groups: Array< | { kind: 'text'; text: string; idx: number } | { kind: 'tools'; tools: DynamicToolPart[]; idx: number } > = []; for (let i = 0; i < parts.length; i++) { const part = parts[i]; if (part.type === 'text') { groups.push({ kind: 'text', text: part.text, idx: i }); } else if (part.type === 'dynamic-tool') { const toolPart = part as DynamicToolPart; const last = groups[groups.length - 1]; if (last?.kind === 'tools') { last.tools.push(toolPart); } else { groups.push({ kind: 'tools', tools: [toolPart], idx: i }); } } // step-start, step-end, etc. are ignored visually } return groups; } export default function AssistantMessage({ message, isStreaming = false, approveTools }: AssistantMessageProps) { const groups = useMemo(() => groupParts(message.parts), [message.parts]); // Find the last text group index for streaming cursor let lastTextIdx = -1; for (let i = groups.length - 1; i >= 0; i--) { if (groups[i].kind === 'text') { lastTextIdx = i; break; } } const meta = message.metadata as MessageMeta | undefined; const timeStr = meta?.createdAt ? new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) : null; if (groups.length === 0) return null; return ( Assistant {timeStr && ( {timeStr} )} {groups.map((group, i) => { if (group.kind === 'text' && group.text) { return ( ); } if (group.kind === 'tools' && group.tools.length > 0) { return ( ); } return null; })} ); } ================================================ FILE: frontend/src/components/Chat/ChatInput.tsx ================================================ import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material'; import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; import StopIcon from '@mui/icons-material/Stop'; import { apiFetch } from '@/utils/api'; import { useUserQuota } from '@/hooks/useUserQuota'; import ClaudeCapDialog from '@/components/ClaudeCapDialog'; import { useAgentStore } from '@/store/agentStore'; import { FIRST_FREE_MODEL_PATH } from '@/utils/model'; // Model configuration interface ModelOption { id: string; name: string; description: string; modelPath: string; avatarUrl: string; recommended?: boolean; } const getHfAvatarUrl = (modelId: string) => { const org = modelId.split('/')[0]; return `https://huggingface.co/api/avatars/${org}`; }; const MODEL_OPTIONS: ModelOption[] = [ { id: 'kimi-k2.6', name: 'Kimi K2.6', description: 'Novita', modelPath: 'moonshotai/Kimi-K2.6', avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.6'), recommended: true, }, { id: 'claude-opus', name: 'Claude Opus 4.6', description: 'Anthropic', modelPath: 'anthropic/claude-opus-4-6', avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', recommended: true, }, { id: 'minimax-m2.7', name: 'MiniMax M2.7', description: 'Novita', modelPath: 'MiniMaxAI/MiniMax-M2.7', avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.7'), }, { id: 'glm-5.1', name: 'GLM 5.1', description: 'Together', modelPath: 'zai-org/GLM-5.1', avatarUrl: getHfAvatarUrl('zai-org/GLM-5.1'), }, ]; const findModelByPath = (path: string): ModelOption | undefined => { return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); }; interface ChatInputProps { sessionId?: string; onSend: (text: string) => void; onStop?: () 
=> void; isProcessing?: boolean; disabled?: boolean; placeholder?: string; } const isClaudeModel = (m: ModelOption) => m.modelPath.startsWith('anthropic/'); const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0]; export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { const [input, setInput] = useState(''); const inputRef = useRef(null); const [selectedModelId, setSelectedModelId] = useState(MODEL_OPTIONS[0].id); const [modelAnchorEl, setModelAnchorEl] = useState(null); const { quota, refresh: refreshQuota } = useUserQuota(); // The daily-cap dialog is triggered from two places: (a) a 429 returned // from the chat transport when the user tries to send on Opus over cap — // surfaced via the agent-store flag — and (b) nothing else right now // (switching models is free). Keeping the open state in the store means // the hook layer can flip it without threading props through. const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted); const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted); const lastSentRef = useRef(''); // Model is per-session: fetch this tab's current model every time the // session changes. Other tabs keep their own selections independently. useEffect(() => { if (!sessionId) return; let cancelled = false; apiFetch(`/api/session/${sessionId}`) .then((res) => (res.ok ? res.json() : null)) .then((data) => { if (cancelled) return; if (data?.model) { const model = findModelByPath(data.model); if (model) setSelectedModelId(model.id); } }) .catch(() => { /* ignore */ }); return () => { cancelled = true; }; }, [sessionId]); const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; // Auto-focus the textarea when the session becomes ready useEffect(() => { if (!disabled && !isProcessing && inputRef.current) { inputRef.current.focus(); } }, [disabled, isProcessing]); const handleSend = useCallback(() => { if (input.trim() && !disabled) { lastSentRef.current = input; onSend(input); setInput(''); } }, [input, disabled, onSend]); // When the chat transport reports a Claude-quota 429, restore the typed // text so the user doesn't lose their message. useEffect(() => { if (claudeQuotaExhausted && lastSentRef.current) { setInput(lastSentRef.current); } }, [claudeQuotaExhausted]); // Refresh the quota display whenever the session changes (user might // have started another tab that spent quota). useEffect(() => { if (sessionId) refreshQuota(); // eslint-disable-next-line react-hooks/exhaustive-deps }, [sessionId]); const handleKeyDown = useCallback( (e: KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); handleSend(); } }, [handleSend] ); const handleModelClick = (event: React.MouseEvent) => { setModelAnchorEl(event.currentTarget); }; const handleModelClose = () => { setModelAnchorEl(null); }; const handleSelectModel = async (model: ModelOption) => { handleModelClose(); if (!sessionId) return; try { const res = await apiFetch(`/api/session/${sessionId}/model`, { method: 'POST', body: JSON.stringify({ model: model.modelPath }), }); if (res.ok) setSelectedModelId(model.id); } catch { /* ignore */ } }; // Dialog close: just clear the flag. The typed text is already restored. 
const handleCapDialogClose = useCallback(() => { setClaudeQuotaExhausted(false); }, [setClaudeQuotaExhausted]); // "Use a free model" — switch the current session to Kimi (or the first // non-Anthropic option) and auto-retry the send that tripped the cap. const handleUseFreeModel = useCallback(async () => { setClaudeQuotaExhausted(false); if (!sessionId) return; const free = MODEL_OPTIONS.find(m => m.modelPath === FIRST_FREE_MODEL_PATH) ?? firstFreeModel(); try { const res = await apiFetch(`/api/session/${sessionId}/model`, { method: 'POST', body: JSON.stringify({ model: free.modelPath }), }); if (res.ok) { setSelectedModelId(free.id); const retryText = lastSentRef.current; if (retryText) { onSend(retryText); setInput(''); lastSentRef.current = ''; } } } catch { /* ignore */ } }, [sessionId, onSend, setClaudeQuotaExhausted]); // Hide the chip until the user has actually burned quota — an unused // Opus session shouldn't populate a counter. const claudeChip = (() => { if (!quota || quota.claudeUsedToday === 0) return null; if (quota.plan === 'free') { return quota.claudeRemaining > 0 ? 'Free today' : 'Pro only'; } return `${quota.claudeUsedToday}/${quota.claudeDailyCap} today`; })(); return ( setInput(e.target.value)} onKeyDown={handleKeyDown} placeholder={placeholder} disabled={disabled || isProcessing} variant="standard" inputRef={inputRef} InputProps={{ disableUnderline: true, sx: { color: 'var(--text)', fontSize: '15px', fontFamily: 'inherit', padding: 0, lineHeight: 1.5, minHeight: { xs: '44px', md: '56px' }, alignItems: 'flex-start', } }} sx={{ flex: 1, '& .MuiInputBase-root': { p: 0, backgroundColor: 'transparent', }, '& textarea': { resize: 'none', padding: '0 !important', } }} /> {isProcessing ? ( ) : ( )} {/* Powered By Badge */} powered by {selectedModel.name} {selectedModel.name} {/* Model Selection Menu */} {MODEL_OPTIONS.map((model) => ( handleSelectModel(model)} selected={selectedModelId === model.id} sx={{ py: 1.5, '&.Mui-selected': { bgcolor: 'rgba(255,255,255,0.05)', } }} > {model.name} {model.name} {model.recommended && ( )} {isClaudeModel(model) && claudeChip && ( )} } secondary={model.description} secondaryTypographyProps={{ sx: { fontSize: '12px', color: 'var(--muted-text)' } }} /> ))} ); } ================================================ FILE: frontend/src/components/Chat/ExpiredBanner.tsx ================================================ /** * Shown inline in a chat when the backend no longer recognizes the * session id (typically: Space was restarted). Lets the user catch the * agent up with a summary of the prior conversation, or start over. 
*/ import { useState, useCallback } from 'react'; import { Box, Button, CircularProgress, Typography } from '@mui/material'; import { apiFetch } from '@/utils/api'; import { useSessionStore } from '@/store/sessionStore'; import { useAgentStore } from '@/store/agentStore'; import { loadBackendMessages } from '@/lib/backend-message-store'; import { loadMessages } from '@/lib/chat-message-store'; import { uiMessagesToLLMMessages } from '@/lib/convert-llm-messages'; import { logger } from '@/utils/logger'; interface Props { sessionId: string; } export default function ExpiredBanner({ sessionId }: Props) { const { renameSession, deleteSession } = useSessionStore(); const [busy, setBusy] = useState<'catch-up' | 'start-over' | null>(null); const [error, setError] = useState(null); const handleCatchUp = useCallback(async () => { setBusy('catch-up'); setError(null); try { // Prefer the raw backend-message cache; fall back to reconstructing // from UIMessages (for sessions that predate the backend cache). let messages = loadBackendMessages(sessionId); if (!messages || messages.length === 0) { const uiMsgs = loadMessages(sessionId); if (uiMsgs.length > 0) messages = uiMessagesToLLMMessages(uiMsgs); } if (!messages || messages.length === 0) { setError('Nothing to summarize from this chat.'); setBusy(null); return; } const res = await apiFetch('/api/session/restore-summary', { method: 'POST', body: JSON.stringify({ messages }), }); if (!res.ok) throw new Error(`restore-summary failed: ${res.status}`); const data = await res.json(); const newId = data.session_id as string | undefined; if (!newId) throw new Error('no session_id in response'); useAgentStore.getState().clearSessionState(sessionId); renameSession(sessionId, newId); } catch (e) { logger.warn('Catch-up failed:', e); setError("Couldn't catch up — try starting over."); setBusy(null); } }, [sessionId, renameSession]); const handleStartOver = useCallback(() => { setBusy('start-over'); useAgentStore.getState().clearSessionState(sessionId); deleteSession(sessionId); }, [sessionId, deleteSession]); return ( Where were we? Let me skim the conversation so far and pick up right where we left off — or we can start something new. {error && ( {error} )} ); } ================================================ FILE: frontend/src/components/Chat/MarkdownContent.tsx ================================================ import { useMemo, useRef, useState, useEffect } from 'react'; import { Box } from '@mui/material'; import ReactMarkdown from 'react-markdown'; import remarkGfm from 'remark-gfm'; import type { SxProps, Theme } from '@mui/material/styles'; interface MarkdownContentProps { content: string; sx?: SxProps; /** When true, shows a blinking cursor and throttles renders. */ isStreaming?: boolean; } /** Shared markdown styles — adapts to light/dark via CSS variables. 
*/ const markdownSx: SxProps = { fontSize: '0.925rem', lineHeight: 1.7, color: 'var(--text)', wordBreak: 'break-word', '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } }, '& h1, & h2, & h3, & h4': { mt: 2.5, mb: 1, fontWeight: 600, lineHeight: 1.3 }, '& h1': { fontSize: '1.35rem' }, '& h2': { fontSize: '1.15rem' }, '& h3': { fontSize: '1.05rem' }, '& pre': { bgcolor: 'var(--code-bg)', p: 2, borderRadius: 2, overflow: 'auto', fontSize: '0.82rem', lineHeight: 1.6, border: '1px solid var(--tool-border)', my: 2, }, '& code': { bgcolor: 'var(--hover-bg)', px: 0.75, py: 0.25, borderRadius: 0.5, fontSize: '0.84rem', fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', }, '& pre code': { bgcolor: 'transparent', p: 0 }, '& a': { color: 'var(--accent-yellow)', textDecoration: 'none', fontWeight: 500, '&:hover': { textDecoration: 'underline' }, }, '& ul, & ol': { pl: 3, my: 1 }, '& li': { mb: 0.5 }, '& li::marker': { color: 'var(--muted-text)' }, '& blockquote': { borderLeft: '3px solid var(--accent-yellow)', pl: 2, ml: 0, my: 1.5, color: 'var(--muted-text)', fontStyle: 'italic', }, '& table': { borderCollapse: 'collapse', width: '100%', my: 2, fontSize: '0.85rem', display: 'block', overflowX: 'auto', WebkitOverflowScrolling: 'touch', }, '& thead': { position: 'sticky', top: 0, }, '& th': { borderBottom: '2px solid var(--border-hover)', bgcolor: 'var(--hover-bg)', textAlign: 'left', px: 1.5, py: 0.75, fontWeight: 600, whiteSpace: 'nowrap', }, '& td': { borderBottom: '1px solid var(--tool-border)', px: 1.5, py: 0.75, }, '& tr:nth-of-type(even) td': { bgcolor: 'color-mix(in srgb, var(--hover-bg) 50%, transparent)', }, '& hr': { border: 'none', borderTop: '1px solid var(--border)', my: 2, }, '& img': { maxWidth: '100%', borderRadius: 2, }, }; /** * Throttled content for streaming: render the full markdown through * ReactMarkdown but only re-parse every ~80ms to avoid layout thrashing. * This is the Claude approach — always render as markdown, never split * into raw text. The parser handles incomplete tables gracefully. 
 */
function useThrottledValue(value: string, isStreaming: boolean, intervalMs = 80): string {
  const [throttled, setThrottled] = useState(value);
  const lastUpdate = useRef(0);
  const pending = useRef<ReturnType<typeof setTimeout> | null>(null);
  const latestValue = useRef(value);
  latestValue.current = value;

  useEffect(() => {
    if (!isStreaming) {
      // Not streaming — always use latest value immediately
      setThrottled(value);
      return;
    }
    const now = Date.now();
    const elapsed = now - lastUpdate.current;
    if (elapsed >= intervalMs) {
      // Enough time passed — update immediately
      setThrottled(value);
      lastUpdate.current = now;
    } else {
      // Schedule an update for the remaining time
      if (pending.current) clearTimeout(pending.current);
      pending.current = setTimeout(() => {
        setThrottled(latestValue.current);
        lastUpdate.current = Date.now();
        pending.current = null;
      }, intervalMs - elapsed);
    }
    return () => {
      if (pending.current) clearTimeout(pending.current);
    };
  }, [value, isStreaming, intervalMs]);

  // When streaming ends, flush immediately
  useEffect(() => {
    if (!isStreaming) {
      setThrottled(latestValue.current);
    }
  }, [isStreaming]);

  return throttled;
}

export default function MarkdownContent({ content, sx, isStreaming = false }: MarkdownContentProps) {
  // Throttle re-parses during streaming to ~12fps (every 80ms)
  const displayContent = useThrottledValue(content, isStreaming);
  const remarkPlugins = useMemo(() => [remarkGfm], []);
  return (
    {displayContent}
  );
}

================================================
FILE: frontend/src/components/Chat/MessageBubble.tsx
================================================
import UserMessage from './UserMessage';
import AssistantMessage from './AssistantMessage';
import type { UIMessage } from 'ai';

interface MessageBubbleProps {
  message: UIMessage;
  isLastTurn?: boolean;
  onUndoTurn?: () => void;
  onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
  isProcessing?: boolean;
  isStreaming?: boolean;
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
}

export default function MessageBubble({
  message,
  isLastTurn = false,
  onUndoTurn,
  onEditAndRegenerate,
  isProcessing = false,
  isStreaming = false,
  approveTools,
}: MessageBubbleProps) {
  if (message.role === 'user') {
    return (
    );
  }
  if (message.role === 'assistant') {
    return (
    );
  }
  return null;
}

================================================
FILE: frontend/src/components/Chat/MessageList.tsx
================================================
import { useCallback, useEffect, useRef, useMemo } from 'react';
import { Box, Stack, Typography } from '@mui/material';
import MessageBubble from './MessageBubble';
import ActivityStatusBar from './ActivityStatusBar';
import { useAgentStore } from '@/store/agentStore';
import type { UIMessage } from 'ai';

interface MessageListProps {
  messages: UIMessage[];
  isProcessing: boolean;
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise<boolean>;
  onUndoLastTurn: () => void | Promise<void>;
  onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise<void>;
}

function getGreeting(): string {
  const h = new Date().getHours();
  if (h < 12) return 'Morning';
  if (h < 17) return 'Afternoon';
  return 'Evening';
}

function WelcomeGreeting() {
  const { user } = useAgentStore();
  const firstName = user?.name?.split(' ')[0] || user?.username;
  const greeting = firstName ? `${getGreeting()}, ${firstName}` : getGreeting();
  return (
    {greeting}
    Let's build something impressive.
); } export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn, onEditAndRegenerate }: MessageListProps) { const scrollContainerRef = useRef(null); const stickToBottom = useRef(true); const scrollToBottom = useCallback(() => { const el = scrollContainerRef.current; if (el) el.scrollTop = el.scrollHeight; }, []); useEffect(() => { const el = scrollContainerRef.current; if (!el) return; const onScroll = () => { const distFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight; stickToBottom.current = distFromBottom < 80; }; el.addEventListener('scroll', onScroll, { passive: true }); return () => el.removeEventListener('scroll', onScroll); }, []); useEffect(() => { if (stickToBottom.current) scrollToBottom(); }, [messages, isProcessing, scrollToBottom]); useEffect(() => { const el = scrollContainerRef.current; if (!el) return; const observer = new MutationObserver(() => { if (stickToBottom.current) el.scrollTop = el.scrollHeight; }); observer.observe(el, { childList: true, subtree: true, characterData: true }); return () => observer.disconnect(); }, []); const lastUserMsgId = useMemo(() => { for (let i = messages.length - 1; i >= 0; i--) { if (messages[i].role === 'user') return messages[i].id; } return null; }, [messages]); // The last assistant message is "streaming" when we're processing const lastAssistantId = useMemo(() => { for (let i = messages.length - 1; i >= 0; i--) { if (messages[i].role === 'assistant') return messages[i].id; } return null; }, [messages]); return ( {messages.length === 0 && !isProcessing ? ( ) : ( messages.map((msg) => ( )) )}
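
// ---------------------------------------------------------------------------
// Sketch: the stick-to-bottom rule MessageList applies above. Auto-follow
// stays engaged only while the reader is within 80px of the bottom, so
// scrolling up to re-read detaches it and scrolling back down re-engages it.
// ---------------------------------------------------------------------------
function shouldStickToBottom(el: { scrollHeight: number; scrollTop: number; clientHeight: number }): boolean {
  const distFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
  return distFromBottom < 80;
}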
);
}

================================================
FILE: frontend/src/components/Chat/ThinkingIndicator.tsx
================================================
import { Box, Typography } from '@mui/material';

/** Pulsing dots shown while the agent is processing. */
export default function ThinkingIndicator() {
  return (
    Thinking
  );
}

================================================
FILE: frontend/src/components/Chat/ToolCallGroup.tsx
================================================
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material';
import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline';
import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline';
import OpenInNewIcon from '@mui/icons-material/OpenInNew';
import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
import LaunchIcon from '@mui/icons-material/Launch';
import SendIcon from '@mui/icons-material/Send';
import BlockIcon from '@mui/icons-material/Block';
import { useAgentStore, type ResearchAgentState } from '@/store/agentStore';
import { useLayoutStore } from '@/store/layoutStore';
import { logger } from '@/utils/logger';
import { RESEARCH_MAX_STEPS } from '@/lib/research-store';
import type { UIMessage } from 'ai';

// ---------------------------------------------------------------------------
// Type helpers — extract the dynamic-tool part type from UIMessage
// ---------------------------------------------------------------------------
type DynamicToolPart = Extract<UIMessage['parts'][number], { type: 'dynamic-tool' }>;
type ToolPartState = DynamicToolPart['state'];

/** Check if a tool part was cancelled (output-error with cancellation message). */
function isCancelledTool(tool: DynamicToolPart): boolean {
  return tool.state === 'output-error' &&
    typeof (tool as Record<string, unknown>).errorText === 'string' &&
    ((tool as Record<string, unknown>).errorText as string).includes('Cancelled by user');
}

interface ToolCallGroupProps {
  tools: DynamicToolPart[];
  approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>) => Promise<boolean>;
}

// ---------------------------------------------------------------------------
// Research sub-steps (inline under the research tool row)
// ---------------------------------------------------------------------------

/** Hook that forces a re-render every second while enabled — used so each
 * research card can compute its own elapsed seconds synchronously from
 * Date.now() without needing its own timer. */
function useSecondTick(enabled: boolean): void {
  const [, setTick] = useState(0);
  useEffect(() => {
    if (!enabled) return;
    const id = setInterval(() => setTick(t => t + 1), 1000);
    return () => clearInterval(id);
  }, [enabled]);
}

/** Compute elapsed seconds from startedAt (or null). Call under useSecondTick. */
function computeElapsed(startedAt: number | null): number | null {
  if (startedAt === null) return null;
  return Math.round((Date.now() - startedAt) / 1000);
}

/** Format token count like the CLI: "12.4k" or "800". */
function formatTokens(tokens: number): string {
  return tokens >= 1000 ? `${(tokens / 1000).toFixed(1)}k` : String(tokens);
}

/** Format elapsed seconds like the CLI: "18s" or "2m 5s". */
function formatElapsed(seconds: number): string {
  if (seconds < 60) return `${seconds}s`;
  return `${Math.floor(seconds / 60)}m ${seconds % 60}s`;
}

/** Build the research stats chip label.
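 */

// ---------------------------------------------------------------------------
// Sketch: how the stat formatters above compose into a chip label; it mirrors
// researchChipLabel below with inlined example values (all hypothetical).
// ---------------------------------------------------------------------------
function exampleChipLabel(): string {
  const toolCount = 7;
  const tokenCount = 12400;   // formatTokens(12400) yields "12.4k"
  const elapsedSeconds = 125; // formatElapsed(125) yields "2m 5s"
  const parts = [
    `${toolCount} tools`,
    `${formatTokens(tokenCount)} tokens`,
    formatElapsed(elapsedSeconds),
  ];
  // Joined with a middle dot: "7 tools · 12.4k tokens · 2m 5s"
  return parts.join(' \u00B7 ');
}

/**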
*/ function researchChipLabel( stats: { toolCount: number; tokenCount: number; startedAt: number | null; finalElapsed: number | null }, liveElapsed: number | null, ): string | null { const elapsed = stats.finalElapsed ?? liveElapsed; if (elapsed === null && stats.toolCount === 0) return null; const parts: string[] = []; if (stats.startedAt !== null) parts.push('running'); if (stats.toolCount > 0) parts.push(`${stats.toolCount} tools`); if (stats.tokenCount > 0) parts.push(`${formatTokens(stats.tokenCount)} tokens`); if (elapsed !== null) parts.push(formatElapsed(elapsed)); return parts.join(' \u00B7 '); } /** Parse JSON args from a step string like "tool_name {json}" (may be truncated at 80 chars). */ function parseStepArgs(step: string): Record { const jsonStart = step.indexOf('{'); if (jsonStart < 0) return {}; const jsonStr = step.slice(jsonStart); try { const parsed = JSON.parse(jsonStr); const result: Record = {}; for (const [k, v] of Object.entries(parsed)) { if (typeof v === 'string') result[k] = v; } return result; } catch { // JSON likely truncated — extract key-value pairs via regex const result: Record = {}; // Match complete "key": "value" pairs for (const m of jsonStr.matchAll(/"(\w+)":\s*"([^"]*)"/g)) { result[m[1]] = m[2]; } // Match truncated trailing value: "key": "value... (no closing quote) if (Object.keys(result).length === 0 || !result.query) { const trunc = jsonStr.match(/"(\w+)":\s*"([^"]+)$/); if (trunc && !result[trunc[1]]) { result[trunc[1]] = trunc[2]; } } return result; } } /** Pretty labels for research sub-agent tool calls */ function formatResearchStep(raw: string): { label: string } { // Backend sends logs like "▸ tool_name {args}" — strip the prefix const step = raw.replace(/^▸\s*/, ''); const args = parseStepArgs(step); if (step.startsWith('github_find_examples')) { const detail = (args.keyword) || (args.repo); return { label: detail ? `Finding examples: ${detail}` : 'Finding examples' }; } if (step.startsWith('github_read_file')) { const path = (args.path) || ''; const filename = path.split('/').pop() || path; return { label: filename ? `Reading ${filename}` : 'Reading file' }; } if (step.startsWith('explore_hf_docs')) { const endpoint = (args.endpoint) || (args.query); return { label: endpoint ? `Exploring docs: ${endpoint}` : 'Exploring docs' }; } if (step.startsWith('fetch_hf_docs')) { const url = (args.url) || ''; const page = url.split('/').pop()?.replace(/\.md$/, ''); return { label: page ? `Reading docs: ${page}` : 'Fetching docs' }; } if (step.startsWith('hf_inspect_dataset')) { const dataset = (args.dataset); return { label: dataset ? `Inspecting dataset: ${dataset}` : 'Inspecting dataset' }; } if (step.startsWith('hf_papers')) { const op = args.operation as string; const detail = (args.query) || (args.arxiv_id); const opLabels: Record = { trending: 'Browsing trending papers', search: 'Searching papers', paper_details: 'Reading paper details', read_paper: 'Reading paper', citation_graph: 'Tracing citations', snippet_search: 'Searching paper snippets', recommend: 'Finding related papers', find_datasets: 'Finding paper datasets', find_models: 'Finding paper models', find_collections: 'Finding paper collections', find_all_resources: 'Finding paper resources', }; const base = (op && opLabels[op]) || 'Searching papers'; return { label: detail ? `${base}: ${detail}` : base }; } if (step.startsWith('find_hf_api')) { const detail = (args.query) || (args.tag); return { label: detail ? 
`Finding API: ${detail}` : 'Finding API endpoints' }; } if (step.startsWith('hf_repo_files')) { const repo = (args.repo_id) || (args.repo); return { label: repo ? `Reading ${repo} files` : 'Reading repo files' }; } if (step.startsWith('read')) { const path = (args.path) || ''; const filename = path.split('/').pop(); return { label: filename ? `Reading ${filename}` : 'Reading file' }; } if (step.startsWith('bash')) { const cmd = args.command as string; const short = cmd && cmd.length > 40 ? cmd.slice(0, 40) + '...' : cmd; return { label: short ? `Running: ${short}` : 'Running command' }; } return { label: step.replace(/^▸\s*/, '') }; } /** Rolling display of research sub-tool calls for a single agent. */ function ResearchSteps({ steps }: { steps: string[] }) { const visible = steps.slice(-RESEARCH_MAX_STEPS); if (visible.length === 0) return null; return ( {visible.map((step, i) => { const { label } = formatResearchStep(step); const isLast = i === visible.length - 1; return ( {isLast ? ( ) : ( )} {label} ); })} ); } // --------------------------------------------------------------------------- // Hardware pricing ($/hr) — from HF Spaces & Jobs pricing // --------------------------------------------------------------------------- const HARDWARE_PRICING: Record = { 'cpu-basic': 'free', 'cpu-upgrade': '$0.03/hr', 't4-small': '$0.60/hr', 't4-medium': '$1.00/hr', 'a10g-small': '$1.05/hr', 'a10g-large': '$3.15/hr', 'a10g-largex2': '$6.30/hr', 'a10g-largex4': '$12.60/hr', 'a100-large': '$4.13/hr', 'a100x4': '$16.52/hr', 'a100x8': '$33.04/hr', 'l4x1': '$0.80/hr', 'l4x4': '$3.20/hr', 'l40sx1': '$1.80/hr', 'l40sx4': '$7.20/hr', 'l40sx8': '$14.40/hr', }; function costLabel(hardware: string): string | null { return HARDWARE_PRICING[hardware] || null; } // --------------------------------------------------------------------------- // Visual helpers // --------------------------------------------------------------------------- function StatusIcon({ state, cancelled, isRejected }: { state: ToolPartState; cancelled?: boolean; isRejected?: boolean }) { if (cancelled || isRejected) { return ; } switch (state) { case 'approval-requested': return ; case 'approval-responded': return ; case 'output-available': return ; case 'output-error': return ; case 'output-denied': return ; case 'input-streaming': case 'input-available': default: return ; } } function statusLabel(state: ToolPartState): string | null { switch (state) { case 'approval-requested': return 'awaiting approval'; case 'approval-responded': return 'approved'; case 'input-streaming': case 'input-available': return 'running'; case 'output-denied': return 'denied'; case 'output-error': return 'error'; default: return null; } } function statusColor(state: ToolPartState): string { switch (state) { case 'approval-requested': return 'var(--accent-yellow)'; case 'approval-responded': return 'var(--accent-green)'; case 'output-available': return 'var(--accent-green)'; case 'output-error': return 'var(--accent-red)'; case 'output-denied': return 'var(--muted-text)'; default: return 'var(--accent-yellow)'; } } // --------------------------------------------------------------------------- // Inline approval UI (per-tool) // --------------------------------------------------------------------------- function InlineApproval({ toolCallId, toolName, input, scriptLabel, onResolve, }: { toolCallId: string; toolName: string; input: unknown; scriptLabel: string; onResolve: (toolCallId: string, approved: boolean, feedback?: string) => void; }) { const [feedback, 
setFeedback] = useState(''); const args = input as Record | undefined; const { setPanel, getEditedScript } = useAgentStore(); const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore(); const hasEditedScript = !!getEditedScript(toolCallId); const handleScriptClick = useCallback(() => { if (toolName === 'hf_jobs' && args?.script) { const scriptContent = getEditedScript(toolCallId) || String(args.script); setPanel( { title: scriptLabel, script: { content: scriptContent, language: 'python' }, parameters: { tool_call_id: toolCallId } }, 'script', true, ); setRightPanelOpen(true); setLeftSidebarOpen(false); } }, [toolCallId, toolName, args, scriptLabel, setPanel, getEditedScript, setRightPanelOpen, setLeftSidebarOpen]); return ( {toolName === 'sandbox_create' && args && (() => { const hw = String(args.hardware || 'cpu-basic'); const cost = costLabel(hw); return ( Create a remote dev environment on{' '} {hw} {cost && ( {' '}({cost}) )} {!!args.private && ( {' (private)'} )} Creates a temporary HF Space to develop and test scripts before running jobs. Takes 1-2 min to start. ); })()} {toolName === 'hf_jobs' && args && (() => { const hw = String(args.hardware_flavor || 'cpu-basic'); const cost = costLabel(hw); return ( Execute {scriptLabel.replace('Script', 'Job')} on{' '} {hw} {cost && ( {' '}({cost}) )} {!!args.timeout && ( <> for up to {String(args.timeout)} )} {typeof args.script === 'string' && args.script && ( {String(args.script).trim()} Click to view & edit )} ); })()} setFeedback(e.target.value)} variant="outlined" sx={{ '& .MuiOutlinedInput-root': { bgcolor: 'var(--hover-bg)', fontFamily: 'inherit', fontSize: '0.8rem', '& fieldset': { borderColor: 'var(--tool-border)' }, '&:hover fieldset': { borderColor: 'var(--border-hover)' }, '&.Mui-focused fieldset': { borderColor: 'var(--accent-yellow)' }, }, '& .MuiOutlinedInput-input': { color: 'var(--text)', '&::placeholder': { color: 'var(--muted-text)', opacity: 0.7 }, }, }} /> onResolve(toolCallId, false, feedback || 'Rejected by user')} disabled={!feedback} size="small" sx={{ color: 'var(--accent-red)', border: '1px solid var(--tool-border)', borderRadius: '6px', '&:hover': { bgcolor: 'rgba(224,90,79,0.1)', borderColor: 'var(--accent-red)' }, '&.Mui-disabled': { color: 'var(--muted-text)', opacity: 0.3 }, }} > ); } // --------------------------------------------------------------------------- // Main component // --------------------------------------------------------------------------- const EMPTY_AGENTS: Record = {}; export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProps) { const { setPanel, lockPanel, getJobUrl, getEditedScript, setJobStatus, getJobStatus, setToolError, getToolError, setToolRejected, getToolRejected } = useAgentStore(); const researchAgents = useAgentStore(s => { const activeId = s.activeSessionId; return (activeId && s.sessionStates[activeId]?.researchAgents) || EMPTY_AGENTS; }); // Tick once per second while any research agent is running so each card's // elapsed seconds update in real time. 
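
  // -------------------------------------------------------------------------
  // Worked examples for parseStepArgs (defined above); the log strings are
  // hypothetical. Complete JSON parses directly:
  //   parseStepArgs('github_find_examples {"keyword": "lora", "repo": "peft"}')
  //     yields { keyword: 'lora', repo: 'peft' }
  // JSON truncated at 80 chars falls back to the regex salvage path:
  //   parseStepArgs('explore_hf_docs {"endpoint": "/models", "query": "text-gen')
  //     yields { endpoint: '/models', query: 'text-gen' }
  // -------------------------------------------------------------------------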
const anyResearchRunning = useMemo( () => Object.values(researchAgents).some(a => a.stats.startedAt !== null), [researchAgents], ); useSecondTick(anyResearchRunning); const isProcessing = useAgentStore(s => s.isProcessing); const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore(); // ── Batch approval state ────────────────────────────────────────── const pendingTools = useMemo( () => tools.filter(t => t.state === 'approval-requested'), [tools], ); const [decisions, setDecisions] = useState>({}); const [isSubmitting, setIsSubmitting] = useState(false); const submittingRef = useRef(false); // Track which toolCallIds we've already submitted so we can detect new approval rounds const submittedIdsRef = useRef>(new Set()); // ── Panel lock state (for auto-follow vs user-selected) ─────────── const [lockedToolId, setLockedToolId] = useState(null); // Reset submission state when new (unseen) pending tools arrive — e.g. second approval round useEffect(() => { if (!isSubmitting || pendingTools.length === 0) return; const hasNewPending = pendingTools.some(t => !submittedIdsRef.current.has(t.toolCallId)); if (hasNewPending) { submittingRef.current = false; setIsSubmitting(false); setDecisions({}); } }, [pendingTools, isSubmitting]); // Clean up stale decisions for tools that are no longer pending useEffect(() => { const pendingIds = new Set(pendingTools.map(t => t.toolCallId)); const decisionIds = Object.keys(decisions); const hasStale = decisionIds.some(id => !pendingIds.has(id)); if (hasStale) { setDecisions(prev => { const cleaned = { ...prev }; for (const id of decisionIds) { if (!pendingIds.has(id)) delete cleaned[id]; } return cleaned; }); } }, [pendingTools, decisions]); // Persist error states when tools error useEffect(() => { for (const tool of tools) { const currentlyHasError = tool.state === 'output-error'; const persistedError = getToolError(tool.toolCallId); // Persist error state if we detect it and haven't already if (currentlyHasError && !persistedError) { setToolError(tool.toolCallId, true); } } }, [tools, setToolError, getToolError]); const { scriptLabelMap, toolDisplayMap } = useMemo(() => { const hfJobs = tools.filter(t => t.toolName === 'hf_jobs' && (t.input as Record)?.script); const scriptMap: Record = {}; const displayMap: Record = {}; for (let i = 0; i < hfJobs.length; i++) { const id = hfJobs[i].toolCallId; if (hfJobs.length > 1) { scriptMap[id] = `Script ${i + 1}`; displayMap[id] = `hf_jobs #${i + 1}`; } else { scriptMap[id] = 'Script'; displayMap[id] = 'hf_jobs'; } } // Pretty name for research tool for (const t of tools) { if (t.toolName === 'research') { displayMap[t.toolCallId] = 'research'; } } return { scriptLabelMap: scriptMap, toolDisplayMap: displayMap }; }, [tools]); // ── Send all decisions as a single batch ────────────────────────── const sendBatch = useCallback( async (batch: Record) => { if (submittingRef.current) return; submittingRef.current = true; setIsSubmitting(true); const approvals = Object.entries(batch).map(([toolCallId, d]) => { const editedScript = d.approved ? (getEditedScript(toolCallId) ?? null) : null; if (editedScript) { logger.log(`Sending edited script for ${toolCallId} (${editedScript.length} chars)`); } // Mark tool as rejected if not approved if (!d.approved) { setToolRejected(toolCallId, true); } return { tool_call_id: toolCallId, approved: d.approved, feedback: d.approved ? 
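          // Note: each entry in the batch payload carries four fields,
          //   { tool_call_id, approved, feedback, edited_script };
          // approved tools attach any user-edited script, rejected tools
          // attach feedback (defaulting to 'Rejected by user') so the agent
          // learns why it was blocked.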
null : (d.feedback || 'Rejected by user'), edited_script: editedScript, }; }); const ok = await approveTools(approvals); if (ok) { // Track which tool IDs were submitted so we can detect new approval rounds for (const a of approvals) submittedIdsRef.current.add(a.tool_call_id); lockPanel(); } else { logger.error('Batch approval failed'); submittingRef.current = false; setIsSubmitting(false); } }, [approveTools, lockPanel, getEditedScript, setToolRejected], ); const handleApproveAll = useCallback(() => { const batch: Record = {}; for (const t of pendingTools) batch[t.toolCallId] = { approved: true }; sendBatch(batch); }, [pendingTools, sendBatch]); const handleRejectAll = useCallback(() => { const batch: Record = {}; for (const t of pendingTools) batch[t.toolCallId] = { approved: false }; sendBatch(batch); }, [pendingTools, sendBatch]); const handleIndividualDecision = useCallback( (toolCallId: string, approved: boolean, feedback?: string) => { setDecisions(prev => { const next = { ...prev, [toolCallId]: { approved, feedback } }; if (pendingTools.every(t => next[t.toolCallId])) { queueMicrotask(() => sendBatch(next)); } return next; }); }, [pendingTools, sendBatch], ); const undoDecision = useCallback((toolCallId: string) => { setDecisions(prev => { const next = { ...prev }; delete next[toolCallId]; return next; }); }, []); // ── Show tool panel (shared logic) ──────────────────────────────── const showToolPanel = useCallback( (tool: DynamicToolPart) => { const args = tool.input as Record | undefined; const displayName = toolDisplayMap[tool.toolCallId] || tool.toolName; if (tool.toolName === 'hf_jobs' && args?.script) { const jobOutput = tool.output ?? (tool.state === 'output-error' ? (tool as Record).errorText : undefined); const hasOutput = (tool.state === 'output-available' || tool.state === 'output-error') && jobOutput; const scriptContent = getEditedScript(tool.toolCallId) || String(args.script); setPanel( { title: displayName, script: { content: scriptContent, language: 'python' }, ...(hasOutput ? { output: { content: String(jobOutput), language: 'markdown' } } : {}), parameters: { tool_call_id: tool.toolCallId }, }, hasOutput ? 'output' : 'script', ); setRightPanelOpen(true); setLeftSidebarOpen(false); return; } const inputSection = args ? { content: JSON.stringify(args, null, 2), language: 'json' } : undefined; const outputText = tool.output ?? (tool.state === 'output-error' ? 
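    // Note: the batch-decision flow above works in three steps. Each
    // InlineApproval records { approved, feedback } into `decisions`; once
    // every pending tool has a decision, queueMicrotask defers sendBatch so
    // the final state commit lands before the network call; sendBatch then
    // posts all approvals at once and remembers their ids in submittedIdsRef,
    // which is how a later approval round is detected and the state reset.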
(tool as Record).errorText : undefined); const hasCompleted = tool.state === 'output-available' || tool.state === 'output-error' || tool.state === 'output-denied'; if (outputText) { // Tool has output - show it (regardless of state) let language = 'text'; const content = String(outputText); if (content.trim().startsWith('{') || content.trim().startsWith('[')) language = 'json'; else if (content.includes('```')) language = 'markdown'; setPanel({ title: displayName, output: { content, language }, input: inputSection }, 'output'); setRightPanelOpen(true); } else if (tool.state === 'output-error') { const content = `Tool \`${tool.toolName}\` returned an error with no output message.`; setPanel({ title: displayName, output: { content, language: 'markdown' }, input: inputSection }, 'output'); setRightPanelOpen(true); } else if (hasCompleted && args) { // Tool completed but has no output - show input as fallback setPanel({ title: displayName, output: { content: JSON.stringify(args, null, 2), language: 'json' }, input: inputSection }, 'output'); setRightPanelOpen(true); } else if (args) { const runningMessages = [ 'Crunching numbers and herding tensors...', 'Teaching the model some new tricks...', 'Consulting the GPU oracle...', 'Wrangling data into submission...', 'Brewing a fresh batch of predictions...', 'Negotiating with the transformer heads...', 'Polishing the attention weights...', 'Aligning the embedding stars...', ]; const funMsg = runningMessages[Math.floor(Math.random() * runningMessages.length)]; setPanel({ title: displayName, output: { content: funMsg, language: 'text' }, input: inputSection }, 'output'); setRightPanelOpen(true); } }, [toolDisplayMap, setPanel, getEditedScript, setRightPanelOpen, setLeftSidebarOpen], ); // ── Panel click handler ─────────────────────────────────────────── const handleClick = useCallback( (tool: DynamicToolPart) => { // Toggle lock: if clicking the same tool that's already locked, unlock it if (lockedToolId === tool.toolCallId) { setLockedToolId(null); return; } // Lock this tool setLockedToolId(tool.toolCallId); // Show the panel showToolPanel(tool); }, [lockedToolId, showToolPanel], ); // ── Auto-follow currently active tool when not locked ───────────── const activeToolIdRef = useRef(null); useEffect(() => { if (lockedToolId !== null) return; // User has locked a tool, don't auto-follow // Find the currently running tool (latest tool that's in progress) const runningTool = tools.slice().reverse().find(t => t.state === 'input-available' || t.state === 'input-streaming' || t.state === 'approval-responded' ); if (runningTool) { // Track this as the active tool and show its panel activeToolIdRef.current = runningTool.toolCallId; showToolPanel(runningTool); } else if (activeToolIdRef.current) { // No running tool, but we were following one - check if it completed const completedTool = tools.find(t => t.toolCallId === activeToolIdRef.current); if (completedTool && (completedTool.state === 'output-available' || completedTool.state === 'output-error')) { // The tool we were following has completed - update its panel showToolPanel(completedTool); } } }, [tools, lockedToolId, showToolPanel]); // ── Parse hf_jobs metadata from output ──────────────────────────── function parseJobMeta(output: unknown): { jobUrl?: string; jobStatus?: string } { if (typeof output !== 'string') return {}; const urlMatch = output.match(/\*\*View at:\*\*\s*(https:\/\/[^\s\n]+)/); const statusMatch = output.match(/\*\*Final Status:\*\*\s*([^\n]+)/); return { jobUrl: urlMatch?.[1], 
jobStatus: statusMatch?.[1]?.trim(), }; } // ── Render ──────────────────────────────────────────────────────── const decidedCount = pendingTools.filter(t => decisions[t.toolCallId]).length; return ( {/* Batch approval header — hidden once user starts deciding individually */} {pendingTools.length > 1 && !isSubmitting && decidedCount === 0 && ( {`${pendingTools.length} tool${pendingTools.length > 1 ? 's' : ''} pending`} )} {/* Tool list */} }> {tools.map((tool) => { const state = tool.state; const isPending = state === 'approval-requested'; const clickable = state === 'output-available' || state === 'output-error' || !!tool.input || (!isProcessing && (state === 'input-available' || state === 'input-streaming')); const localDecision = decisions[tool.toolCallId]; const cancelled = isCancelledTool(tool); const currentlyHasError = state === 'output-error'; const persistedError = getToolError(tool.toolCallId); const persistedRejection = getToolRejected(tool.toolCallId); // Stale in-progress tools after page reload: treat as completed const stale = !isProcessing && (state === 'input-available' || state === 'input-streaming'); const displayState = stale ? 'output-available' : isPending && localDecision ? (localDecision.approved ? 'input-available' : 'output-denied') : state; const isRejected = displayState === 'output-denied' || persistedRejection; const hasError = (persistedError || currentlyHasError) && !isRejected; const label = cancelled ? 'cancelled' : isRejected ? 'rejected' : hasError ? 'error' : statusLabel(displayState as ToolPartState); // Parse job metadata from hf_jobs output and store const jobUrlFromStore = tool.toolName === 'hf_jobs' ? getJobUrl(tool.toolCallId) : undefined; const jobStatusFromStore = tool.toolName === 'hf_jobs' ? getJobStatus(tool.toolCallId) : undefined; const jobMetaFromOutput = tool.toolName === 'hf_jobs' && (tool.output || (tool as Record).errorText) ? parseJobMeta(tool.output ?? (tool as Record).errorText) : {}; // Store job status if we just parsed it and don't have it stored yet if (tool.toolName === 'hf_jobs' && jobMetaFromOutput.jobStatus && !jobStatusFromStore) { setJobStatus(tool.toolCallId, jobMetaFromOutput.jobStatus); } // Combine job URL and status from store (persisted) with output metadata (freshly parsed) // Prefer stored values to ensure they persist across renders const jobMeta = { jobUrl: jobUrlFromStore || jobMetaFromOutput.jobUrl, jobStatus: jobStatusFromStore || jobMetaFromOutput.jobStatus, }; return ( {/* Main tool row */} !isPending && handleClick(tool)} sx={{ px: 1.5, py: 1, cursor: isPending ? 'default' : clickable ? 'pointer' : 'default', transition: 'background-color 0.15s', bgcolor: lockedToolId === tool.toolCallId ? 'var(--hover-bg)' : 'transparent', borderLeft: lockedToolId === tool.toolCallId ? '3px solid var(--accent-yellow)' : '3px solid transparent', '&:hover': clickable && !isPending ? { bgcolor: 'var(--hover-bg)' } : {}, }} > {toolDisplayMap[tool.toolCallId] || tool.toolName} {/* Status chip (non hf_jobs, or hf_jobs without final status) */} {(() => { // Research tool: override chip label with this card's agent stats const agentState: ResearchAgentState | undefined = tool.toolName === 'research' ? researchAgents[tool.toolCallId] : undefined; const researchDone = cancelled || state === 'output-available' || state === 'output-error' || state === 'output-denied'; const liveElapsed = agentState ? computeElapsed(agentState.stats.startedAt) : null; const researchLabel = tool.toolName === 'research' && agentState ? 
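      // Note: parseJobMeta (above) pulls both fields out of the hf_jobs
      // markdown output with two regexes; given hypothetical output
      //   "**View at:** https://huggingface.co/jobs/u/123\n**Final Status:** COMPLETED"
      // it returns { jobUrl: 'https://huggingface.co/jobs/u/123', jobStatus: 'COMPLETED' }.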
(researchDone && agentState.stats.finalElapsed !== null ? researchChipLabel({ ...agentState.stats, startedAt: null }, null) : researchChipLabel(agentState.stats, liveElapsed)) : null; const chipLabel = researchLabel || label; if (!chipLabel || (tool.toolName === 'hf_jobs' && jobMeta.jobStatus)) return null; return ( ); })()} {/* HF Jobs: final status chip from job metadata */} {tool.toolName === 'hf_jobs' && jobMeta.jobStatus && ( )} {/* View on HF link — single place, shown whenever URL is available */} {tool.toolName === 'hf_jobs' && jobMeta.jobUrl && ( e.stopPropagation()} sx={{ display: 'inline-flex', alignItems: 'center', gap: 0.5, color: 'var(--accent-yellow)', fontSize: '0.68rem', textDecoration: 'none', ml: 0.5, '&:hover': { textDecoration: 'underline' }, }} > View on HF )} {clickable && !isPending && ( )} {/* Research sub-agent rolling steps (visible only while running) */} {tool.toolName === 'research' && !cancelled && state !== 'output-available' && state !== 'output-error' && state !== 'output-denied' && researchAgents[tool.toolCallId] && ( )} {/* Per-tool approval: undecided */} {isPending && !localDecision && !isSubmitting && ( )} {/* Per-tool approval: locally decided (undo available) */} {isPending && localDecision && !isSubmitting && ( {localDecision.approved ? 'Marked for approval' : `Marked for rejection${localDecision.feedback ? `: ${localDecision.feedback}` : ''}`} )} ); })} ); } ================================================ FILE: frontend/src/components/Chat/UserMessage.tsx ================================================ import { useState, useRef, useEffect } from 'react'; import { Box, Stack, Typography, IconButton, Tooltip, TextField } from '@mui/material'; import CloseIcon from '@mui/icons-material/Close'; import EditIcon from '@mui/icons-material/Edit'; import CheckIcon from '@mui/icons-material/Check'; import type { UIMessage } from 'ai'; import type { MessageMeta } from '@/types/agent'; interface UserMessageProps { message: UIMessage; isLastTurn?: boolean; onUndoTurn?: () => void; onEditAndRegenerate?: (messageId: string, newText: string) => void | Promise; isProcessing?: boolean; } function extractText(message: UIMessage): string { return message.parts .filter((p): p is Extract => p.type === 'text') .map(p => p.text) .join(''); } export default function UserMessage({ message, isLastTurn = false, onUndoTurn, onEditAndRegenerate, isProcessing = false, }: UserMessageProps) { const showUndo = isLastTurn && !isProcessing && !!onUndoTurn; const showEdit = !isProcessing && !!onEditAndRegenerate; const text = extractText(message); const meta = message.metadata as MessageMeta | undefined; const timeStr = meta?.createdAt ? 
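  // Note: the edit-in-place contract implemented below: Enter (without Shift)
  // confirms and regenerates, Escape cancels, and confirming with empty or
  // unchanged text silently exits edit mode without triggering a resend.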
new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) : null; const [isEditing, setIsEditing] = useState(false); const [editText, setEditText] = useState(text); const inputRef = useRef(null); useEffect(() => { if (isEditing && inputRef.current) { inputRef.current.focus(); inputRef.current.selectionStart = inputRef.current.value.length; } }, [isEditing]); const handleStartEdit = () => { setEditText(text); setIsEditing(true); }; const handleConfirmEdit = () => { const trimmed = editText.trim(); if (!trimmed || trimmed === text) { setIsEditing(false); return; } setIsEditing(false); onEditAndRegenerate?.(message.id, trimmed); }; const handleCancelEdit = () => { setIsEditing(false); setEditText(text); }; const handleKeyDown = (e: React.KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); handleConfirmEdit(); } else if (e.key === 'Escape') { handleCancelEdit(); } }; return ( {!isEditing && (showUndo || showEdit) && ( {showEdit && ( )} {showUndo && ( )} )} {isEditing ? ( setEditText(e.target.value)} onKeyDown={handleKeyDown} variant="outlined" size="small" sx={{ '& .MuiOutlinedInput-root': { fontFamily: 'inherit', fontSize: '0.925rem', lineHeight: 1.65, color: 'var(--text)', '& fieldset': { borderColor: 'var(--accent-yellow)', borderWidth: 1.5 }, '&:hover fieldset': { borderColor: 'var(--accent-yellow)' }, '&.Mui-focused fieldset': { borderColor: 'var(--accent-yellow)' }, }, }} /> ) : ( {text} )} {timeStr && !isEditing && ( {timeStr} )} ); } ================================================ FILE: frontend/src/components/ClaudeCapDialog.tsx ================================================ import { Box, Button, Dialog, DialogActions, DialogContent, DialogContentText, DialogTitle, Typography, } from '@mui/material'; import type { PlanTier } from '@/hooks/useUserQuota'; const HF_PRICING_URL = 'https://huggingface.co/pricing'; const PRO_CAP = 20; interface ClaudeCapDialogProps { open: boolean; plan: PlanTier; cap: number; onClose: () => void; onUseFreeModel: () => void; } export default function ClaudeCapDialog({ open, plan, cap, onClose, onUseFreeModel, }: ClaudeCapDialogProps) { // plan not surfaced in copy right now — Pro users see the same dialog and // can upgrade their org if they're also capped. void plan; return ( You've hit your Opus limit Opus costs an arm and a leg, so we unfortunately have to cap you at {cap}{' '} {cap === 1 ? 'session' : 'sessions'} a day. Give Kimi, MiniMax, or GLM a spin — they are genuinely good and we use them all the time. HF Pro ($9/mo) — more Opus, more everything {PRO_CAP} Opus sessions/day here, 20× HF Inference credits, ZeroGPU access, and priority on Spaces hardware. 
); } ================================================ FILE: frontend/src/components/CodePanel/CodePanel.tsx ================================================ import { useRef, useEffect, useMemo, useState, useCallback } from 'react'; import { Box, Stack, Typography, IconButton, Button, Tooltip } from '@mui/material'; import CloseIcon from '@mui/icons-material/Close'; import RadioButtonUncheckedIcon from '@mui/icons-material/RadioButtonUnchecked'; import CheckCircleIcon from '@mui/icons-material/CheckCircle'; import PlayCircleOutlineIcon from '@mui/icons-material/PlayCircleOutline'; import CodeIcon from '@mui/icons-material/Code'; import ArticleIcon from '@mui/icons-material/Article'; import EditIcon from '@mui/icons-material/Edit'; import UndoIcon from '@mui/icons-material/Undo'; import ContentCopyIcon from '@mui/icons-material/ContentCopy'; import CheckIcon from '@mui/icons-material/Check'; import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; import { vscDarkPlus, vs } from 'react-syntax-highlighter/dist/esm/styles/prism'; import ReactMarkdown from 'react-markdown'; import remarkGfm from 'remark-gfm'; import { useAgentStore } from '@/store/agentStore'; import { useLayoutStore } from '@/store/layoutStore'; import { processLogs } from '@/utils/logProcessor'; import type { PanelView } from '@/store/agentStore'; // ── Helpers ────────────────────────────────────────────────────── function PlanStatusIcon({ status }: { status: string }) { if (status === 'completed') return ; if (status === 'in_progress') return ; return ; } // ── Markdown styles (adapts via CSS vars) ──────────────────────── const markdownSx = { color: 'var(--text)', fontSize: '13px', lineHeight: 1.6, '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } }, '& pre': { bgcolor: 'var(--code-bg)', p: 1.5, borderRadius: 1, overflow: 'auto', fontSize: '12px', border: '1px solid var(--tool-border)', }, '& code': { bgcolor: 'var(--hover-bg)', px: 0.5, py: 0.25, borderRadius: 0.5, fontSize: '12px', fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', }, '& pre code': { bgcolor: 'transparent', p: 0 }, '& a': { color: 'var(--accent-yellow)', textDecoration: 'none', '&:hover': { textDecoration: 'underline' }, }, '& ul, & ol': { pl: 2.5, my: 1 }, '& li': { mb: 0.5 }, '& table': { borderCollapse: 'collapse', width: '100%', my: 2, fontSize: '12px', fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', }, '& th': { borderBottom: '2px solid var(--border-hover)', textAlign: 'left', p: 1, fontWeight: 600, }, '& td': { borderBottom: '1px solid var(--tool-border)', p: 1, }, '& h1, & h2, & h3, & h4': { mt: 2, mb: 1, fontWeight: 600 }, '& h1': { fontSize: '1.25rem' }, '& h2': { fontSize: '1.1rem' }, '& h3': { fontSize: '1rem' }, '& blockquote': { borderLeft: '3px solid var(--accent-yellow)', pl: 2, ml: 0, color: 'var(--muted-text)', }, } as const; // ── View toggle button ────────────────────────────────────────── function ViewToggle({ view, icon, label, isActive, onClick }: { view: PanelView; icon: React.ReactNode; label: string; isActive: boolean; onClick: (v: PanelView) => void; }) { return ( onClick(view)} sx={{ display: 'flex', alignItems: 'center', gap: 0.5, px: 1.5, py: 0.75, borderRadius: 1, cursor: 'pointer', fontSize: '0.7rem', fontWeight: 600, textTransform: 'uppercase', letterSpacing: '0.05em', whiteSpace: 'nowrap', color: isActive ? 'var(--text)' : 'var(--muted-text)', bgcolor: isActive ? 'var(--tab-active-bg)' : 'transparent', border: '1px solid', borderColor: isActive ? 
'var(--tab-active-border)' : 'transparent', transition: 'all 0.15s ease', '&:hover': { bgcolor: 'var(--tab-hover-bg)' }, }} > {icon} {label} ); } // ── Component ──────────────────────────────────────────────────── export default function CodePanel() { const { panelData, panelView, panelEditable, setPanelView, updatePanelScript, setEditedScript, plan } = useAgentStore(); const { setRightPanelOpen, themeMode } = useLayoutStore(); const scrollRef = useRef(null); const textareaRef = useRef(null); const [isEditing, setIsEditing] = useState(false); const [editedContent, setEditedContent] = useState(''); const [originalContent, setOriginalContent] = useState(''); const [copied, setCopied] = useState(false); const [showInput, setShowInput] = useState(false); const isDark = themeMode === 'dark'; const syntaxTheme = isDark ? vscDarkPlus : vs; const activeSection = panelView === 'script' ? panelData?.script : panelData?.output; const hasScript = !!panelData?.script; const hasOutput = !!panelData?.output; const hasBothViews = hasScript && hasOutput; const isEditableScript = panelView === 'script' && panelEditable; const hasUnsavedChanges = isEditing && editedContent !== originalContent; // Reset input toggle when panel data changes useEffect(() => { setShowInput(false); }, [panelData]); // Sync edited content when panel data changes useEffect(() => { if (panelData?.script?.content && panelView === 'script' && panelEditable) { setOriginalContent(panelData.script.content); if (!isEditing) { setEditedContent(panelData.script.content); } } }, [panelData?.script?.content, panelView, panelEditable, isEditing]); // Exit editing when switching away from script view or losing editable useEffect(() => { if (!isEditableScript && isEditing) { setIsEditing(false); } }, [isEditableScript, isEditing]); const handleStartEdit = useCallback(() => { if (panelData?.script?.content) { setEditedContent(panelData.script.content); setOriginalContent(panelData.script.content); setIsEditing(true); setTimeout(() => textareaRef.current?.focus(), 0); } }, [panelData?.script?.content]); const handleCancelEdit = useCallback(() => { setEditedContent(originalContent); setIsEditing(false); }, [originalContent]); const handleSaveEdit = useCallback(() => { if (editedContent !== originalContent) { updatePanelScript(editedContent); const toolCallId = panelData?.parameters?.tool_call_id as string | undefined; if (toolCallId) { setEditedScript(toolCallId, editedContent); } setOriginalContent(editedContent); } setIsEditing(false); }, [panelData?.parameters?.tool_call_id, editedContent, originalContent, updatePanelScript, setEditedScript]); const handleCopy = useCallback(async () => { const contentToCopy = isEditing ? editedContent : (activeSection?.content || ''); if (contentToCopy) { try { await navigator.clipboard.writeText(contentToCopy); setCopied(true); setTimeout(() => setCopied(false), 2000); } catch (err) { console.error('Failed to copy:', err); } } }, [isEditing, editedContent, activeSection?.content]); const visibleSection = (showInput && panelData?.input) ? 
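  // Note: the copy affordance above writes the visible content (an in-flight
  // edit wins over the stored script) via navigator.clipboard.writeText, then
  // flips a transient `copied` flag back after 2s to restore the icon.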
panelData.input : activeSection; const displayContent = useMemo(() => { if (!visibleSection?.content) return ''; if (!visibleSection.language || visibleSection.language === 'text') { return processLogs(visibleSection.content); } return visibleSection.content; }, [visibleSection?.content, visibleSection?.language]); // Auto-scroll only for live log streaming, not when opening panel const hasAutoScrolled = useRef(false); useEffect(() => { hasAutoScrolled.current = false; }, [panelData]); useEffect(() => { if (scrollRef.current && panelView === 'output' && hasAutoScrolled.current) { scrollRef.current.scrollTop = scrollRef.current.scrollHeight; } hasAutoScrolled.current = true; }, [displayContent, panelView]); // ── Syntax-highlighted code block (DRY) ──────────────────────── const renderSyntaxBlock = (language: string) => ( {displayContent} ); // ── Content renderer ─────────────────────────────────────────── const renderContent = () => { if (!visibleSection?.content) { return ( NO CONTENT TO DISPLAY ); } if (!showInput && isEditing && isEditableScript) { return ( {editedContent || ' '}
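
// ---------------------------------------------------------------------------
// Sketch: the "follow live logs, but don't jump on open" rule CodePanel uses
// above, reduced to a plain function. `firstRenderAfterOpen` models the
// hasAutoScrolled ref being reset when panelData changes; names are
// illustrative, not the component's API.
// ---------------------------------------------------------------------------
function maybeFollowLogs(
  el: { scrollTop: number; scrollHeight: number },
  view: 'script' | 'output',
  firstRenderAfterOpen: boolean,
): boolean {
  // Skip the first content render after the panel (re)opens, so opening a
  // finished log does not yank the view to the bottom.
  if (firstRenderAfterOpen) return false;
  // Keep following appended output while logs stream in.
  if (view !== 'output') return false;
  el.scrollTop = el.scrollHeight;
  return true;
}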